U
    Ph"                     @  s   d dl mZ d dlmZ d dlZd dlmZ d dlmZm	Z	 d dl
mZ d dlmZ ddgZd	d
dddd
ddddddddddddddddddddfZG dd dejZG dd dejZdS )    )annotations)SequenceN)ADNConvolution)
ChannelPad)ChannelMatchingHighResBlock
HighResNetconv_0      )name
n_featureskernel_sizeZres_1r   r   )r   r   kernelsrepeatZres_2    Zres_3@   conv_1P      conv_2)r   r   c                      sd   e Zd Zdddddifdddifdejfd	d	d	d
ddddddd
 fddZdddddZ  ZS )r   r   r   batchaffineTreluinplaceFintzSequence[int]zSequence[int] | intztuple | strboolChannelMatching | strNone)
spatial_dimsin_channelsout_channelsr   dilation	norm_type	acti_typebiaschannel_matchingreturnc
                   s   t    t||||	d| _t }
|| }}|D ]<}|
td||||d |
t||||||dd |}q2tj	|
 | _
dS )aT  
        Args:
            spatial_dims: number of spatial dimensions of the input image.
            in_channels: number of input channels.
            out_channels: number of output channels.
            kernels: each integer k in `kernels` corresponds to a convolution layer with kernel size k.
            dilation: spacing between kernel elements.
            norm_type: feature normalization type and arguments.
                Defaults to ``("batch", {"affine": True})``.
            acti_type: {``"relu"``, ``"prelu"``, ``"relu6"``}
                Non-linear activation using ReLU or PReLU. Defaults to ``"relu"``.
            bias: whether to have a bias term in convolution blocks. Defaults to False.
                According to `Performance Tuning Guide <https://pytorch.org/tutorials/recipes/recipes/tuning_guide.html>`_,
                if a conv layer is directly followed by a batch norm layer, bias should be False.
            channel_matching: {``"pad"``, ``"project"``}
                Specifies handling residual branch and conv branch channel mismatches. Defaults to ``"pad"``.

                - ``"pad"``: with zero padding.
                - ``"project"``: with a trainable conv with kernel size one.

        Raises:
            ValueError: When ``channel_matching=pad`` and ``in_channels > out_channels``. Incompatible values.

        )r!   r"   r#   modeNA)orderingr"   actnormnorm_dimT)r!   r"   r#   r   r$   r'   	conv_onlyN)super__init__r   chn_padnn
ModuleListappendr   r   
Sequentiallayers)selfr!   r"   r#   r   r$   r%   r&   r'   r(   r8   _in_chns	_out_chnsr   	__class__ S/home/dell461/cl/sdc2/HISourceFinder-master-l/src/monai/networks/nets/highresnet.pyr2   (   s4    $
   
zHighResBlock.__init__torch.Tensorxr)   c                 C  s   |  |}|t| | S N)r8   torch	as_tensorr3   )r9   rB   Zx_convr>   r>   r?   forwardg   s    
zHighResBlock.forward)__name__
__module____qualname__r   PADr2   rF   __classcell__r>   r>   r<   r?   r   &   s   

&?c                      sn   e Zd ZdZddddddifdddifd	d
eejf	ddddddddddd
 fddZdddddZ  Z	S )r	   a  
    Reimplementation of highres3dnet based on
    Li et al., "On the compactness, efficiency, and representation of 3D
    convolutional networks: Brain parcellation as a pretext task", IPMI '17

    Adapted from:
    https://github.com/NifTK/NiftyNet/blob/v0.6.0/niftynet/network/highres3dnet.py
    https://github.com/fepegar/highresnet

    Args:
        spatial_dims: number of spatial dimensions of the input image.
        in_channels: number of input channels.
        out_channels: number of output channels.
        norm_type: feature normalization type and arguments.
            Defaults to ``("batch", {"affine": True})``.
        acti_type: activation type and arguments.
            Defaults to ``("relu", {"inplace": True})``.
        dropout_prob: probability of the feature map to be zeroed
            (only applies to the penultimate conv layer).
        bias: whether to have a bias term in convolution blocks. Defaults to False.
            According to `Performance Tuning Guide <https://pytorch.org/tutorials/recipes/recipes/tuning_guide.html>`_,
            if a conv layer is directly followed by a batch norm layer, bias should be False.
        layer_params: specifying key parameters of each layer/block.
        channel_matching: {``"pad"``, ``"project"``}
            Specifies handling residual branch and conv branch channel mismatches. Defaults to ``"pad"``.

            - ``"pad"``: with zero padding.
            - ``"project"``: with a trainable conv with kernel size one.
    r   r   r   r   Tr   r   g        Fr   zstr | tupleztuple | str | float | Noner   zSequence[dict]r   r    )
r!   r"   r#   r%   r&   dropout_probr'   layer_paramsr(   r)   c
                   s.  t    t }
|d }||d  }}|
t||||d d|||d t|dd D ]X\}}||d  }}d| }t|d	 D ],}|
t||||d
 |||||	d	 |}qqZ|d }||d  }}|
t||||d d||||d	 |d }|}|
t||||d d||||d	 tj	|
 | _
d S )Nr   r   r   r+   )r!   r"   r#   r   adn_orderingr-   r.   r'   r      r   r   )	r!   r"   r#   r   r$   r%   r&   r'   r(   ZNAD)	r!   r"   r#   r   rN   r-   r.   r'   dropout)r1   r2   r4   r5   r6   r   	enumerateranger   r7   blocks)r9   r!   r"   r#   r%   r&   rL   r'   rM   r(   rU   paramsr:   r;   idxZ	_dilation_r<   r>   r?   r2      s    
zHighResNet.__init__r@   rA   c                 C  s   t | |S rC   )rD   rE   rU   )r9   rB   r>   r>   r?   rF      s    zHighResNet.forward)
rG   rH   rI   __doc__DEFAULT_LAYER_PARAMS_3Dr   rJ   r2   rF   rK   r>   r>   r<   r?   r	   l   s    

&V)
__future__r   collections.abcr   rD   torch.nnr4   Zmonai.networks.blocksr   r   Z"monai.networks.layers.simplelayersr   monai.utilsr   __all__rZ   Moduler   r	   r>   r>   r>   r?   <module>   s    

F