o
    )i<                      @  sr   d dl mZ d dlmZ d dlmZ d dlmZ d dlm	Z	m
Z
mZ d dlmZ d dlmZ G dd	 d	ejZdS )
    )annotations)SequenceN)UnetOutBlock)UnetrBasicBlockUnetrPrUpBlockUnetrUpBlock)ViT)ensure_tuple_repc                      sL   e Zd ZdZ											
		d)d* fd#d$Zd%d& Zd'd( Z  ZS )+UNETRz
    UNETR based on: "Hatamizadeh et al.,
    UNETR: Transformers for 3D Medical Image Segmentation <https://arxiv.org/abs/2103.10504>"
                convinstanceT           Fin_channelsintout_channelsimg_sizeSequence[int] | intfeature_sizehidden_sizemlp_dim	num_heads	proj_typestr	norm_nametuple | str
conv_blockbool	res_blockdropout_ratefloatspatial_dimsqkv_bias	save_attnreturnNonec                   s  t    d|  krdkstd td|| dkr!tdd| _t||}td|| _tdd t|| jD | _|| _	d	| _
t||| j||| j||| j
||||d
| _t|||dd|	|d| _t|||d dddd|	|
|d
| _t|||d dddd|	|
|d
| _t|||d dddd|	|
|d
| _t|||d dd|	|d| _t||d |d dd|	|d| _t||d |d dd|	|d| _t||d |dd|	|d| _t|||d| _d|d ftdd t|D  | _t| j| j	g | _dS )aT  
        Args:
            in_channels: dimension of input channels.
            out_channels: dimension of output channels.
            img_size: dimension of input image.
            feature_size: dimension of network feature size. Defaults to 16.
            hidden_size: dimension of hidden layer. Defaults to 768.
            mlp_dim: dimension of feedforward layer. Defaults to 3072.
            num_heads: number of attention heads. Defaults to 12.
            proj_type: patch embedding layer type. Defaults to "conv".
            norm_name: feature normalization type and arguments. Defaults to "instance".
            conv_block: if convolutional block is used. Defaults to True.
            res_block: if residual block is used. Defaults to True.
            dropout_rate: fraction of the input units to drop. Defaults to 0.0.
            spatial_dims: number of spatial dims. Defaults to 3.
            qkv_bias: apply the bias term for the qkv linear layer in self attention block. Defaults to False.
            save_attn: to make accessible the attention in self attention block. Defaults to False.

        Examples::

            # for single channel input 4-channel output with image size of (96,96,96), feature size of 32 and batch norm
            >>> net = UNETR(in_channels=1, out_channels=4, img_size=(96,96,96), feature_size=32, norm_name='batch')

             # for single channel input 4-channel output with image size of (96,96), feature size of 32 and batch norm
            >>> net = UNETR(in_channels=1, out_channels=4, img_size=96, feature_size=32, norm_name='batch', spatial_dims=2)

            # for 4-channel input 3-channel output with image size of (128,128,128), conv position embedding and instance norm
            >>> net = UNETR(in_channels=4, out_channels=3, img_size=(128,128,128), proj_type='conv', norm_name='instance')

        r      z'dropout_rate should be between 0 and 1.z-hidden_size should be divisible by num_heads.r   r   c                 s  s    | ]	\}}|| V  qd S )N ).0Zimg_dp_dr+   r+   [/home/dell461/cl/sdc2/last_ska_mid/HISourceFinder-master-l/src/monai/networks/nets/unetr.py	<genexpr>Z   s    z!UNETR.__init__.<locals>.<genexpr>F)r   r   
patch_sizer   r   
num_layersr   r   classificationr#   r%   r&   r'   r   )r%   r   r   kernel_sizestrider   r"      )
r%   r   r   	num_layerr3   r4   upsample_kernel_sizer   r    r"         )r%   r   r   r3   r7   r   r"   )r%   r   r   c                 s  s    | ]}|d  V  qdS )r*   Nr+   )r,   dr+   r+   r.   r/      s    N)super__init__
ValueErrorr1   r	   r0   tuplezipZ	feat_sizer   r2   r   vitr   encoder1r   encoder2encoder3encoder4r   decoder5decoder4decoder3decoder2r   outrange	proj_axeslistproj_view_shape)selfr   r   r   r   r   r   r   r   r   r    r"   r#   r%   r&   r'   	__class__r+   r.   r<      s   
1
					$zUNETR.__init__c                 C  s0   | dg| j }||}|| j }|S )Nr   )sizerM   viewpermuterK   
contiguous)rN   xnew_viewr+   r+   r.   	proj_feat   s   
zUNETR.proj_featc                 C  s   |  |\}}| |}|d }| | |}|d }| | |}|d }	| | |	}
| |}| ||
}| ||}| ||}| 	||}| 
|S )Nr      	   )r@   rA   rB   rW   rC   rD   rE   rF   rG   rH   rI   )rN   x_inrU   hidden_states_outenc1x2enc2x3enc3x4Zenc4dec4dec3dec2dec1rI   r+   r+   r.   forward   s   


zUNETR.forward)r   r   r   r   r   r   TTr   r   FF) r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r    r!   r"   r!   r#   r$   r%   r   r&   r!   r'   r!   r(   r)   )__name__
__module____qualname____doc__r<   rW   rf   __classcell__r+   r+   rO   r.   r
      s$    
 $r
   )
__future__r   collections.abcr   torch.nnnnZ#monai.networks.blocks.dynunet_blockr   Z!monai.networks.blocks.unetr_blockr   r   r   Zmonai.networks.nets.vitr   monai.utilsr	   Moduler
   r+   r+   r+   r.   <module>   s   