o
    ,i"                     @  s   d dl mZ d dlmZ d dlZd dlmZ d dlmZm	Z	 d dl
mZ d dlmZ ddgZd	d
dddd
ddddddddddddddddddddfZG dd dejZG dd dejZdS )    )annotations)SequenceN)ADNConvolution)
ChannelPad)ChannelMatchingHighResBlock
HighResNetconv_0      )name
n_featureskernel_sizeZres_1r   r   )r   r   kernelsrepeatZres_2    Zres_3@   conv_1P      conv_2)r   r   c                      sH   e Zd Zdddddifdddifdejfd  fddZd!ddZ  ZS )"r   r   r   batchaffineTreluinplaceFspatial_dimsintin_channelsout_channelsr   Sequence[int]dilationSequence[int] | int	norm_typetuple | str	acti_typebiasboolchannel_matchingChannelMatching | strreturnNonec
                   s   t    t||||	d| _t }
||}}|D ]}|
td||||d |
t||||||dd |}qtj	|
 | _
dS )aT  
        Args:
            spatial_dims: number of spatial dimensions of the input image.
            in_channels: number of input channels.
            out_channels: number of output channels.
            kernels: each integer k in `kernels` corresponds to a convolution layer with kernel size k.
            dilation: spacing between kernel elements.
            norm_type: feature normalization type and arguments.
                Defaults to ``("batch", {"affine": True})``.
            acti_type: {``"relu"``, ``"prelu"``, ``"relu6"``}
                Non-linear activation using ReLU or PReLU. Defaults to ``"relu"``.
            bias: whether to have a bias term in convolution blocks. Defaults to False.
                According to `Performance Tuning Guide <https://pytorch.org/tutorials/recipes/recipes/tuning_guide.html>`_,
                if a conv layer is directly followed by a batch norm layer, bias should be False.
            channel_matching: {``"pad"``, ``"project"``}
                Specifies handling residual branch and conv branch channel mismatches. Defaults to ``"pad"``.

                - ``"pad"``: with zero padding.
                - ``"project"``: with a trainable conv with kernel size one.

        Raises:
            ValueError: When ``channel_matching=pad`` and ``in_channels > out_channels``. Incompatible values.

        )r   r   r    modeNA)orderingr   actnormnorm_dimT)r   r   r    r   r"   r'   	conv_onlyN)super__init__r   chn_padnn
ModuleListappendr   r   
Sequentiallayers)selfr   r   r    r   r"   r$   r&   r'   r)   r;   _in_chns	_out_chnsr   	__class__ `/home/dell461/cl/sdc2/last_ska_mid/HISourceFinder-master-l/src/monai/networks/nets/highresnet.pyr5   (   s.   
$
zHighResBlock.__init__xtorch.Tensorc                 C  s   |  |}|t| | S N)r;   torch	as_tensorr6   )r<   rC   Zx_convrA   rA   rB   forwardg   s   
zHighResBlock.forward)r   r   r   r   r    r   r   r!   r"   r#   r$   r%   r&   r%   r'   r(   r)   r*   r+   r,   rC   rD   r+   rD   )__name__
__module____qualname__r   PADr5   rH   __classcell__rA   rA   r?   rB   r   &   s    

?c                	      sR   e Zd ZdZddddddifdddifd	d
eejf	d" fddZd#d d!Z  Z	S )$r	   a  
    Reimplementation of highres3dnet based on
    Li et al., "On the compactness, efficiency, and representation of 3D
    convolutional networks: Brain parcellation as a pretext task", IPMI '17

    Adapted from:
    https://github.com/NifTK/NiftyNet/blob/v0.6.0/niftynet/network/highres3dnet.py
    https://github.com/fepegar/highresnet

    Args:
        spatial_dims: number of spatial dimensions of the input image.
        in_channels: number of input channels.
        out_channels: number of output channels.
        norm_type: feature normalization type and arguments.
            Defaults to ``("batch", {"affine": True})``.
        acti_type: activation type and arguments.
            Defaults to ``("relu", {"inplace": True})``.
        dropout_prob: probability of the feature map to be zeroed
            (only applies to the penultimate conv layer).
        bias: whether to have a bias term in convolution blocks. Defaults to False.
            According to `Performance Tuning Guide <https://pytorch.org/tutorials/recipes/recipes/tuning_guide.html>`_,
            if a conv layer is directly followed by a batch norm layer, bias should be False.
        layer_params: specifying key parameters of each layer/block.
        channel_matching: {``"pad"``, ``"project"``}
            Specifies handling residual branch and conv branch channel mismatches. Defaults to ``"pad"``.

            - ``"pad"``: with zero padding.
            - ``"project"``: with a trainable conv with kernel size one.
    r   r   r   r   Tr   r   g        Fr   r   r   r    r$   str | tupler&   dropout_probtuple | str | float | Noner'   r(   layer_paramsSequence[dict]r)   r*   r+   r,   c
                   s.  t    t }
|d }||d }}|
t||||d d|||d t|dd D ],\}}||d }}d| }t|d	 D ]}|
t||||d
 |||||	d	 |}qBq-|d }||d }}|
t||||d d||||d	 |d }|}|
t||||d d||||d	 tj	|
 | _
d S )Nr   r   r   r.   )r   r   r    r   adn_orderingr0   r1   r'   r      r   r   )	r   r   r    r   r"   r$   r&   r'   r)   ZNAD)	r   r   r    r   rT   r0   r1   r'   dropout)r4   r5   r7   r8   r9   r   	enumerateranger   r:   blocks)r<   r   r   r    r$   r&   rP   r'   rR   r)   r[   paramsr=   r>   idxZ	_dilation_r?   rA   rB   r5      s   
zHighResNet.__init__rC   rD   c                 C  s   t | |S rE   )rF   rG   r[   )r<   rC   rA   rA   rB   rH      s   zHighResNet.forward)r   r   r   r   r    r   r$   rO   r&   rO   rP   rQ   r'   r(   rR   rS   r)   r*   r+   r,   rI   )
rJ   rK   rL   __doc__DEFAULT_LAYER_PARAMS_3Dr   rM   r5   rH   rN   rA   rA   r?   rB   r	   l   s     

V)
__future__r   collections.abcr   rF   torch.nnr7   monai.networks.blocksr   r   "monai.networks.layers.simplelayersr   monai.utilsr   __all__r`   Moduler   r	   rA   rA   rA   rB   <module>   s"   

F