U
    Pha!                     @  sv   d dl mZ d dlmZ d dlmZ d dlmZ d dlm	Z	m
Z
mZ d dlmZ d dlmZmZ G dd	 d	ejZdS )
    )annotations)SequenceN)UnetOutBlock)UnetrBasicBlockUnetrPrUpBlockUnetrUpBlock)ViT)deprecated_argensure_tuple_repc                      sj   e Zd ZdZeddddddd dddddddddddddddddd fddZdd Zdd Z  ZS )!UNETRz
    UNETR based on: "Hatamizadeh et al.,
    UNETR: Transformers for 3D Medical Image Segmentation <https://arxiv.org/abs/2103.10504>"
    	pos_embedz1.2z1.4	proj_typezplease use `proj_type` instead.)namesinceremovednew_name
msg_suffix            convinstanceT           FintzSequence[int] | intstrztuple | strboolfloatNone)in_channelsout_channelsimg_sizefeature_sizehidden_sizemlp_dim	num_headsr   r   	norm_name
conv_block	res_blockdropout_ratespatial_dimsqkv_bias	save_attnreturnc                   s  t    d|  krdks(n td|| dkr<tdd| _t||}td|| _tdd t|| jD | _|| _	d	| _
t||| j||| j||	| j
||||d
| _t|||dd|
|d| _t|||d dddd|
||d
| _t|||d dddd|
||d
| _t|||d dddd|
||d
| _t|||d dd|
|d| _t||d |d dd|
|d| _t||d |d dd|
|d| _t||d |dd|
|d| _t|||d| _d|d ftdd t|D  | _t| j| j	g | _dS )a  
        Args:
            in_channels: dimension of input channels.
            out_channels: dimension of output channels.
            img_size: dimension of input image.
            feature_size: dimension of network feature size. Defaults to 16.
            hidden_size: dimension of hidden layer. Defaults to 768.
            mlp_dim: dimension of feedforward layer. Defaults to 3072.
            num_heads: number of attention heads. Defaults to 12.
            proj_type: patch embedding layer type. Defaults to "conv".
            norm_name: feature normalization type and arguments. Defaults to "instance".
            conv_block: if convolutional block is used. Defaults to True.
            res_block: if residual block is used. Defaults to True.
            dropout_rate: fraction of the input units to drop. Defaults to 0.0.
            spatial_dims: number of spatial dims. Defaults to 3.
            qkv_bias: apply the bias term for the qkv linear layer in self attention block. Defaults to False.
            save_attn: to make accessible the attention in self attention block. Defaults to False.

        .. deprecated:: 1.4
            ``pos_embed`` is deprecated in favor of ``proj_type``.

        Examples::

            # for single channel input 4-channel output with image size of (96,96,96), feature size of 32 and batch norm
            >>> net = UNETR(in_channels=1, out_channels=4, img_size=(96,96,96), feature_size=32, norm_name='batch')

             # for single channel input 4-channel output with image size of (96,96), feature size of 32 and batch norm
            >>> net = UNETR(in_channels=1, out_channels=4, img_size=96, feature_size=32, norm_name='batch', spatial_dims=2)

            # for 4-channel input 3-channel output with image size of (128,128,128), conv position embedding and instance norm
            >>> net = UNETR(in_channels=4, out_channels=3, img_size=(128,128,128), proj_type='conv', norm_name='instance')

        r      z'dropout_rate should be between 0 and 1.z-hidden_size should be divisible by num_heads.r   r   c                 s  s   | ]\}}|| V  qd S )N ).0Zimg_dp_dr0   r0   N/home/dell461/cl/sdc2/HISourceFinder-master-l/src/monai/networks/nets/unetr.py	<genexpr>a   s     z!UNETR.__init__.<locals>.<genexpr>F)r    r"   
patch_sizer$   r%   
num_layersr&   r   classificationr*   r+   r,   r-   r   )r+   r    r!   kernel_sizestrider'   r)      )
r+   r    r!   	num_layerr8   r9   upsample_kernel_sizer'   r(   r)         )r+   r    r!   r8   r<   r'   r)   )r+   r    r!   c                 s  s   | ]}|d  V  qdS )r/   Nr0   )r1   dr0   r0   r3   r4      s     N)super__init__
ValueErrorr6   r
   r5   tuplezipZ	feat_sizer$   r7   r   vitr   encoder1r   encoder2encoder3encoder4r   decoder5decoder4decoder3decoder2r   outrange	proj_axeslistproj_view_shape)selfr    r!   r"   r#   r$   r%   r&   r   r   r'   r(   r)   r*   r+   r,   r-   	__class__r0   r3   rA      s    8

					$zUNETR.__init__c                 C  s0   | dg| j }||}|| j }|S )Nr   )sizerR   viewpermuterP   
contiguous)rS   xnew_viewr0   r0   r3   	proj_feat   s    
zUNETR.proj_featc                 C  s   |  |\}}| |}|d }| | |}|d }| | |}|d }	| | |	}
| |}| ||
}| ||}| ||}| 	||}| 
|S )Nr      	   )rE   rF   rG   r\   rH   rI   rJ   rK   rL   rM   rN   )rS   x_inrZ   hidden_states_outenc1x2enc2x3enc3x4Zenc4dec4dec3dec2dec1rN   r0   r0   r3   forward   s    

zUNETR.forward)r   r   r   r   r   r   r   TTr   r   FF)	__name__
__module____qualname____doc__r	   rA   r\   rk   __classcell__r0   r0   rT   r3   r      s2                    6 (r   )
__future__r   collections.abcr   torch.nnnnZ#monai.networks.blocks.dynunet_blockr   Z!monai.networks.blocks.unetr_blockr   r   r   Zmonai.networks.nets.vitr   monai.utilsr	   r
   Moduler   r0   r0   r0   r3   <module>   s   