o
    )i)                     @  s   d Z ddlmZ ddlmZ ddlmZ ddlmZ ddl	Z	ddl
m  mZ ddl	mZmZ ddlmZmZ g d	ZG d
d dejZG dd deZG dd deZG dd dejZdS )z
This script is modified from from torchvision to support N-D images,
by overriding the definition of convolutional layers and pooling layers.

https://github.com/pytorch/vision/blob/release/0.12/torchvision/ops/feature_pyramid_network.py
    )annotations)OrderedDict)Callable)castN)Tensornn)ConvPool)ExtraFPNBlockLastLevelMaxPoolLastLevelP6P7FeaturePyramidNetworkc                   @  s   e Zd ZdZd
ddZd	S )r
   z
    Base class for the extra block in the FPN.

    Same code as https://github.com/pytorch/vision/blob/release/0.12/torchvision/ops/feature_pyramid_network.py
    resultslist[Tensor]xnames	list[str]c                 C  s   dS )av  
        Compute extended set of results of the FPN and their names.

        Args:
            results: the result of the FPN
            x: the original feature maps
            names: the names for each one of the original feature maps

        Returns:
            - the extended set of results of the FPN
            - the extended set of names for the results
        N selfr   r   r   r   r   o/home/dell461/cl/sdc2/last_ska_mid/HISourceFinder-master-l/src/monai/networks/blocks/feature_pyramid_network.pyforwardK   s   zExtraFPNBlock.forwardN)r   r   r   r   r   r   )__name__
__module____qualname____doc__r   r   r   r   r   r
   D   s    r
   c                      s,   e Zd ZdZd fddZdddZ  ZS )r   z
    Applies a max_pool2d or max_pool3d on top of the last feature map. Serves as an ``extra_blocks``
    in :class:`~monai.networks.blocks.feature_pyramid_network.FeaturePyramidNetwork` .
    spatial_dimsintc                   s,   t    ttj|f }|dddd| _d S )N      r   kernel_sizestridepadding)super__init__r	   MAXmaxpool)r   r   	pool_type	__class__r   r   r%   a   s   
zLastLevelMaxPool.__init__r   r   r   r   r   returntuple[list[Tensor], list[str]]c                 C  s&   | d | | |d  ||fS )Npool)appendr'   r   r   r   r   r   f   s   
zLastLevelMaxPool.forward)r   r   r   r   r   r   r   r   r+   r,   r   r   r   r   r%   r   __classcell__r   r   r)   r   r   [   s    r   c                      s,   e Zd ZdZd fddZdddZ  ZS )r   z
    This module is used in RetinaNet to generate extra layers, P6 and P7.
    Serves as an ``extra_blocks``
    in :class:`~monai.networks.blocks.feature_pyramid_network.FeaturePyramidNetwork` .
    r   r   in_channelsout_channelsc                   s   t    ttj|f }|||dddd| _|||dddd| _| j| jfD ]}tjj|j	dd tj
|jd q&||k| _d S )N   r   r   r    ar   )r$   r%   r   CONVp6p7r   initkaiming_uniform_weight	constant_biasuse_P5)r   r   r3   r4   	conv_typemoduler)   r   r   r%   s   s   
zLastLevelP6P7.__init__r   r   r   r   r   r+   r,   c           	      C  s^   |d |d }}| j r|n|}| |}| t|}|||g |ddg ||fS )Nr.   r9   r:   )r@   r9   r:   Freluextend)	r   r   r   r   p5c5Zx5r9   r:   r   r   r   r   }   s   
zLastLevelP6P7.forward)r   r   r3   r   r4   r   r0   r1   r   r   r)   r   r   l   s    
r   c                      sD   e Zd ZdZ	dd fd
dZdddZdddZdddZ  ZS )r   a  
    Module that adds a FPN from on top of a set of feature maps. This is based on
    `"Feature Pyramid Network for Object Detection" <https://arxiv.org/abs/1612.03144>`_.

    The feature maps are currently supposed to be in increasing depth
    order.

    The input to the model is expected to be an OrderedDict[Tensor], containing
    the feature maps on top of which the FPN will be added.

    Args:
        spatial_dims: 2D or 3D images
        in_channels_list: number of channels for each feature map that
            is passed to the module
        out_channels: number of channels of the FPN representation
        extra_blocks: if provided, extra operations will
            be performed. It is expected to take the fpn features, the original
            features and the names of the original features as input, and returns
            a new list of feature maps and their corresponding names

    Examples::

        >>> m = FeaturePyramidNetwork(2, [10, 20, 30], 5)
        >>> # get some dummy data
        >>> x = OrderedDict()
        >>> x['feat0'] = torch.rand(1, 10, 64, 64)
        >>> x['feat2'] = torch.rand(1, 20, 16, 16)
        >>> x['feat3'] = torch.rand(1, 30, 8, 8)
        >>> # compute the FPN on top of x
        >>> output = m(x)
        >>> print([(k, v.shape) for k, v in output.items()])
        >>> # returns
        >>>   [('feat0', torch.Size([1, 5, 64, 64])),
        >>>    ('feat2', torch.Size([1, 5, 16, 16])),
        >>>    ('feat3', torch.Size([1, 5, 8, 8]))]

    Nr   r   in_channels_list	list[int]r4   extra_blocksExtraFPNBlock | Nonec                   s   t    ttj|f }t | _t | _|D ]$}|dkr"td|||d}|||ddd}| j	| | j	| qttj|f }	| 
 D ] }
t|
|	rhtjjttj|
jdd tjttj|
jd qH|d urtt|tstt|| _d S )Nr   z(in_channels=0 is currently not supportedr   r5   )r#   r6   g        )r$   r%   r   r8   r   
ModuleListinner_blockslayer_blocks
ValueErrorr/   modules
isinstancer;   r<   r   torchr   r=   r>   r?   r
   AssertionErrorrJ   )r   r   rH   r4   rJ   rA   r3   Zinner_block_moduleZlayer_block_moduleZ
conv_type_mr)   r   r   r%      s*   





zFeaturePyramidNetwork.__init__r   r   idxr+   c                 C  F   t | j}|dk r||7 }|}t| jD ]\}}||kr ||}q|S )zs
        This is equivalent to self.inner_blocks[idx](x),
        but torchscript doesn't support this yet
        r   )lenrM   	enumerater   r   rU   
num_blocksoutirB   r   r   r   get_result_from_inner_blocks      
z2FeaturePyramidNetwork.get_result_from_inner_blocksc                 C  rV   )zs
        This is equivalent to self.layer_blocks[idx](x),
        but torchscript doesn't support this yet
        r   )rW   rN   rX   rY   r   r   r   get_result_from_layer_blocks   r^   z2FeaturePyramidNetwork.get_result_from_layer_blocksdict[str, Tensor]c                 C  s   t | }t | }| |d d}g }|| |d tt|d ddD ]'}| || |}|jdd }t	j
||dd}	||	 }|d| || q)| jdur_| |||\}}tt t||}
|
S )z
        Computes the FPN for a set of feature maps.

        Args:
            x: feature maps for each feature level.

        Returns:
            feature maps after FPN layers. They are ordered from highest resolution first.
        r.   r   Nnearest)sizemoder   )listkeysvaluesr]   r/   r_   rangerW   shaperC   interpolateinsertrJ   r   zip)r   r   r   x_valuesZ
last_innerr   rU   Zinner_lateralZ
feat_shapeZinner_top_downr[   r   r   r   r      s   
zFeaturePyramidNetwork.forward)N)r   r   rH   rI   r4   r   rJ   rK   )r   r   rU   r   r+   r   )r   r`   r+   r`   )	r   r   r   r   r%   r]   r_   r   r2   r   r   r)   r   r      s    +
!
r   )r   
__future__r   collectionsr   collections.abcr   typingr   rR   torch.nn.functionalr   
functionalrC   r   monai.networks.layers.factoriesr   r	   __all__Moduler
   r   r   r   r   r   r   r   <module>   s   -