U
    Ph                  
   @  s  d dl mZ d dlZd dlmZ d dlZd dlZd dlm	Z	 d dl
m	  mZ d dlm  mZ d dlmZ d dlmZ d dlmZ d dlmZmZmZmZ d dlmZmZ d d	lmZmZm Z  d d
l!m"Z" e ddd\Z#Z$ddddddddddg
Z%G dd de	j&Z'dd Z(dd Z)d(ddZ*G dd de	j&Z+G dd de	j&Z,G dd de	j&Z-G d d de-Z.e.e-d!Z/d"d# Z0G d$d de	j&Z1G d%d de	j&Z2d&d' Z3dS ))    )annotationsN)Sequence)	LayerNorm)Final)MLPBlock)
PatchEmbedUnetOutBlockUnetrBasicBlockUnetrUpBlock)DropPathtrunc_normal_)ensure_tuple_replook_up_optionoptional_import)deprecated_argeinops	rearrange)name	SwinUNETRwindow_partitionwindow_reverseWindowAttentionSwinTransformerBlockPatchMergingPatchMergingV2MERGING_MODE
BasicLayerSwinTransformerc                      s   e Zd ZU dZdZded< eddddd	d#ddddddddddddddd fddZdd Ze	j
jdd  Zd!d" Z  ZS )$r   z
    Swin UNETR based on: "Hatamizadeh et al.,
    Swin UNETR: Swin Transformers for Semantic Segmentation of Brain Tumors in MRI Images
    <https://arxiv.org/abs/2201.01266>"
       z
Final[int]
patch_sizeimg_sizez1.3z1.5zdThe img_size argument is not required anymore and checks on the input size are run during forward().)r   sinceremoved
msg_suffixr   r   r   r               r)   instance        TFr&   mergingzSequence[int] | intintSequence[int]ztuple | strfloatboolNone)r    in_channelsout_channelsdepths	num_headsfeature_size	norm_name	drop_rateattn_drop_ratedropout_path_rate	normalizeuse_checkpointspatial_dimsreturnc                   s0  t    t||}t| j|}td|}|dkr:td| | d|  krXdksbn tdd|	  krvdksn tdd|
  krdksn td|d	 dkrtd
|| _t||||||dd||	|
tj	||t
|trt|tn||d| _t|||dd|dd| _t|||dd|dd| _t|d| d| dd|dd| _t|d| d| dd|dd| _t|d| d| dd|dd| _t|d| d| dd|dd| _t||d |d dd|dd| _t||d |d dd|dd| _t||d |dd|dd| _t|||dd|dd| _t|||d| _dS )a  
        Args:
            img_size: spatial dimension of input image.
                This argument is only used for checking that the input image size is divisible by the patch size.
                The tensor passed to forward() can have a dynamic shape as long as its spatial dimensions are divisible by 2**5.
                It will be removed in an upcoming version.
            in_channels: dimension of input channels.
            out_channels: dimension of output channels.
            feature_size: dimension of network feature size.
            depths: number of layers in each stage.
            num_heads: number of attention heads.
            norm_name: feature normalization type and arguments.
            drop_rate: dropout rate.
            attn_drop_rate: attention dropout rate.
            dropout_path_rate: drop path rate.
            normalize: normalize output intermediate features in each stage.
            use_checkpoint: use gradient checkpointing for reduced memory usage.
            spatial_dims: number of spatial dims.
            downsample: module used for downsampling, available options are `"mergingv2"`, `"merging"` and a
                user-specified `nn.Module` following the API defined in :py:class:`monai.networks.nets.PatchMerging`.
                The default is currently `"merging"` (the original version defined in v0.9.0).
            use_v2: using swinunetr_v2, which adds a residual convolution block at the beggining of each swin stage.

        Examples::

            # for 3D single channel input with size (96,96,96), 4-channel output and feature size of 48.
            >>> net = SwinUNETR(img_size=(96,96,96), in_channels=1, out_channels=4, feature_size=48)

            # for 3D 4-channel input with size (128,128,128), 3-channel output and (2,4,2,2) layers in each stage.
            >>> net = SwinUNETR(img_size=(128,128,128), in_channels=4, out_channels=3, depths=(2,4,2,2))

            # for 2D single channel input with size (96,96), 2-channel output and gradient checkpointing.
            >>> net = SwinUNETR(img_size=(96,96), in_channels=3, out_channels=2, use_checkpoint=True, spatial_dims=2)

           )r   r&   z#spatial dimension should be 2 or 3.r      z'dropout rate should be between 0 and 1.z1attention dropout rate should be between 0 and 1.z)drop path rate should be between 0 and 1.r(   z'feature_size should be divisible by 12.      @T)in_chans	embed_dimwindow_sizer   r4   r5   	mlp_ratioqkv_biasr8   r9   drop_path_rate
norm_layerr<   r=   
downsampleuse_v2r&   r=   r2   r3   kernel_sizestrider7   	res_blockr            )r=   r2   r3   rL   upsample_kernel_sizer7   rN   )r=   r2   r3   N)super__init__r   r   
ValueError_check_input_sizer;   r   nnr   
isinstancestrr   r   swinViTr	   encoder1encoder2encoder3encoder4	encoder10r
   decoder5decoder4decoder3decoder2decoder1r   out)selfr    r2   r3   r4   r5   r6   r7   r8   r9   r:   r;   r<   r=   rI   rJ   Zpatch_sizesrD   	__class__ S/home/dell461/cl/sdc2/HISourceFinder-master-l/src/monai/networks/nets/swin_unetr.pyrT   8   s    =










	

zSwinUNETR.__init__c              	   C  s  t  n | jjjj|d d  | jjjj|d d  | jjd j	
 D ]\}}|j||dd qR| jjd jjj|d d  | jjd jjj|d d  | jjd jjj|d d	  | jjd j	
 D ]\}}|j||d
d q| jjd jjj|d d  | jjd jjj|d d  | jjd jjj|d d  | jjd j	
 D ]\}}|j||dd qj| jjd jjj|d d  | jjd jjj|d d  | jjd jjj|d d  | jjd j	
 D ]\}}|j||dd q| jjd jjj|d d  | jjd jjj|d d  | jjd jjj|d d  W 5 Q R X d S )N
state_dictzmodule.patch_embed.proj.weightzmodule.patch_embed.proj.biasr   layers1)n_blocklayerz,module.layers1.0.downsample.reduction.weightz'module.layers1.0.downsample.norm.weightz%module.layers1.0.downsample.norm.biaslayers2z,module.layers2.0.downsample.reduction.weightz'module.layers2.0.downsample.norm.weightz%module.layers2.0.downsample.norm.biaslayers3z,module.layers3.0.downsample.reduction.weightz'module.layers3.0.downsample.norm.weightz%module.layers3.0.downsample.norm.biaslayers4z,module.layers4.0.downsample.reduction.weightz'module.layers4.0.downsample.norm.weightz%module.layers4.0.downsample.norm.bias)torchno_gradrZ   patch_embedprojweightcopy_biasrl   blocksnamed_children	load_fromrI   	reductionnormro   rp   rq   )rf   weightsZbnameblockri   ri   rj   r{     s^    











zSwinUNETR.load_fromc                 C  s`   t |}|t | jd dk}| r\t |d d  }td| d| d| j dd S )N   r   r   zspatial dimensions z  of input image (spatial shape: z) must be divisible by z**5.)nparraypowerr   anywheretolistrU   )rf   spatial_shaper    	remainderZ
wrong_dimsri   ri   rj   rV   7  s    
zSwinUNETR._check_input_sizec                 C  s   t j s| |jdd   | || j}| |}| |d }| 	|d }| 
|d }| |d }| ||d }| ||}	| |	|}
| |
|}| ||}| |}|S )Nr   r   r@   rO   r&   )rr   jitis_scriptingrV   shaperZ   r;   r[   r\   r]   r^   r_   r`   ra   rb   rc   rd   re   )rf   x_inZhidden_states_outZenc0Zenc1Zenc2Zenc3Zdec4dec3dec2dec1dec0re   logitsri   ri   rj   forwardB  s    


zSwinUNETR.forward)r$   r%   r)   r*   r+   r+   r+   TFr&   r,   F)__name__
__module____qualname____doc__r   __annotations__r   rT   r{   rr   r   unusedrV   r   __classcell__ri   ri   rg   rj   r   /   s4   
            0 H1

c           	   
   C  s  |   }t|dkr|\}}}}}| |||d  |d ||d  |d ||d  |d |} | dddddddd d	|d |d  |d  |}nvt|dkr| j\}}}}| |||d  |d ||d  |d |} | dddddd d	|d |d  |}|S )
a)  window partition operation based on: "Liu et al.,
    Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
    <https://arxiv.org/abs/2103.14030>"
    https://github.com/microsoft/Swin-Transformer

     Args:
        x: input tensor.
        window_size: local window size.
    r   r   r@   r   r&   rO   r'   r?   )sizelenviewpermute
contiguousr   )	xrD   x_shapebdhwcwindowsri   ri   rj   r   T  s(    



8,.c              
   C  s   t |dkr|\}}}}| |||d  ||d  ||d  |d |d |d d}|dddddddd	 ||||d}nft |dkr|\}}}| |||d  ||d  |d |d d}|dddddd |||d}|S )
aO  window reverse operation based on: "Liu et al.,
    Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
    <https://arxiv.org/abs/2103.14030>"
    https://github.com/microsoft/Swin-Transformer

     Args:
        windows: windows tensor.
        window_size: local window size.
        dims: dimension values.
    rO   r   r@   r   r   r   r&   r'   r?   )r   r   r   r   )r   rD   dimsr   r   r   r   r   ri   ri   rj   r   u  s$    



,
,$c                 C  sz   t |}|dk	rt |}tt| D ]0}| | || kr$| | ||< |dk	r$d||< q$|dkrft|S t|t|fS dS )aQ  Computing window size based on: "Liu et al.,
    Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
    <https://arxiv.org/abs/2103.14030>"
    https://github.com/microsoft/Swin-Transformer

     Args:
        x_size: input size.
        window_size: local window size.
        shift_size: window shifting size.
    Nr   )listranger   tuple)x_sizerD   
shift_sizeZuse_window_sizeZuse_shift_sizeiri   ri   rj   get_window_size  s    
r   c                	      s<   e Zd ZdZddddddddd	 fd
dZdd Z  ZS )r   a  
    Window based multi-head self attention module with relative position bias based on: "Liu et al.,
    Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
    <https://arxiv.org/abs/2103.14030>"
    https://github.com/microsoft/Swin-Transformer
    Fr+   r-   r.   r0   r/   r1   )dimr5   rD   rF   	attn_drop	proj_dropr>   c                   s  t    || _|| _|| _|| }|d | _tjj}t	| jdkr"t
td| jd  d d| jd  d  d| jd  d  || _t| jd }	t| jd }
t| jd }|dk	rttj|	|
|dd}ntt|	|
|}t|d}|dddddf |dddddf  }|ddd }|dddddf  | jd d 7  < |dddddf  | jd d 7  < |dddddf  | jd d 7  < |dddddf  d| jd  d d| jd  d  9  < |dddddf  d| jd  d 9  < nZt	| jdkr|t
td|d  d d|d  d  || _t| jd }
t| jd }|dk	rttj|
|dd}ntt|
|}t|d}|dddddf |dddddf  }|ddd }|dddddf  | jd d 7  < |dddddf  | jd d 7  < |dddddf  d| jd  d 9  < |d	}| d
| t
j||d |d| _t
|| _t
||| _t
|| _t| jdd t
jd	d| _dS )aA  
        Args:
            dim: number of feature channels.
            num_heads: number of attention heads.
            window_size: local window size.
            qkv_bias: add a learnable bias to query, key, value.
            attn_drop: attention dropout rate.
            proj_drop: dropout rate of output.
        g      r&   r   r   r@   Nij)indexingr   relative_position_indexrx   g{Gz?)std)r   )rS   rT   r   rD   r5   scalerr   meshgrid__kwdefaults__r   rW   	Parameterzerosrelative_position_bias_tablearangestackflattenr   r   sumregister_bufferLinearqkvDropoutr   ru   r   r   Softmaxsoftmax)rf   r   r5   rD   rF   r   r   head_dimZ	mesh_argsZcoords_dcoords_hcoords_wcoordscoords_flattenrelative_coordsr   rg   ri   rj   rT     sf    

4,(((>0&
,((,
zWindowAttention.__init__c                 C  sh  |j \}}}| |||d| j|| j ddddd}|d |d |d   }}}	|| j }||dd }
| j| j	 d |d |f d ||d}|ddd
 }|
|d }
|d k	r|j d }|
|| || j|||dd }
|
d| j||}
| |
}
n
| |
}
| |
|	j}
|
|	 dd|||}| |}| |}|S )Nr&   r   r   r@   rO   r   )r   r   reshaper5   r   r   	transposer   r   cloner   	unsqueezer   r   r   todtyperu   r   )rf   r   maskr   nr   r   qkvattnrelative_position_biasnwri   ri   rj   r     s2    .
  

(


zWindowAttention.forward)Fr+   r+   )r   r   r   r   rT   r   r   ri   ri   rg   rj   r     s       Kc                      sr   e Zd ZdZddddddejdfddddd	d
d	d	d	ddd
dd fddZdd Zdd Zdd Z	dd Z
  ZS )r   z
    Swin Transformer block based on: "Liu et al.,
    Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
    <https://arxiv.org/abs/2103.14030>"
    https://github.com/microsoft/Swin-Transformer
    rA   Tr+   GELUFr-   r.   r/   r0   rY   type[LayerNorm]r1   )r   r5   rD   r   rE   rF   dropr   	drop_path	act_layerrH   r<   r>   c                   s   t    || _|| _|| _|| _|| _|| _||| _t	|| j||||d| _
|	dkr`t|	nt | _||| _t|| }t|||
|dd| _dS )as  
        Args:
            dim: number of feature channels.
            num_heads: number of attention heads.
            window_size: local window size.
            shift_size: window shift size.
            mlp_ratio: ratio of mlp hidden dim to embedding dim.
            qkv_bias: add a learnable bias to query, key, value.
            drop: dropout rate.
            attn_drop: attention dropout rate.
            drop_path: stochastic depth rate.
            act_layer: activation layer.
            norm_layer: normalization layer.
            use_checkpoint: use gradient checkpointing for reduced memory usage.
        )rD   r5   rF   r   r   r+   swin)hidden_sizemlp_dimactdropout_ratedropout_modeN)rS   rT   r   r5   rD   r   rE   r<   norm1r   r   r   rW   Identityr   norm2r-   Mlpmlp)rf   r   r5   rD   r   rE   rF   r   r   r   r   rH   r<   Zmlp_hidden_dimrg   ri   rj   rT   %  s(    

	
zSwinTransformerBlock.__init__c                 C  s^  |  }| |}t|dkr|j\}}}}}t|||f| j| j\}	}
d } }}|	d ||	d   |	d  }|	d ||	d   |	d  }|	d ||	d   |	d  }t|dd||||||f}|j\}}}}}||||g}nt|dkr|j\}}}}t||f| j| j\}	}
d }}|	d ||	d   |	d  }|	d ||	d   |	d  }t|dd||||f}|j\}}}}|||g}t	dd |
D rt|dkrt
j||
d  |
d  |
d  fdd	}n.t|dkr t
j||
d  |
d  fd
d	}|}n|}d }t||	}| j||d}|jd|	|f  }t||	|}t	dd |
D rt|dkrt
j||
d |
d |
d fdd	}n*t|dkrt
j||
d |
d fd
d	}n|}t|dkr|dks|dks|dkrZ|d d d |d |d |d d f  }nFt|dkrZ|dks6|dkrZ|d d d |d |d d f  }|S )Nr   r   r@   r   rO   c                 s  s   | ]}|d kV  qdS r   Nri   .0r   ri   ri   rj   	<genexpr>r  s     z5SwinTransformerBlock.forward_part1.<locals>.<genexpr>)r@   r   r&   )shiftsr   )r@   r   )r   r   c                 s  s   | ]}|d kV  qdS r   ri   r   ri   ri   rj   r     s     )r   )r   r   r   r   r   rD   r   Fpadr   rr   rollr   r   r   r   r   )rf   r   mask_matrixr   r   r   r   r   r   rD   r   pad_lpad_tZpad_d0Zpad_d1pad_bpad_r_dphpwpr   Z	shifted_x	attn_maskZ	x_windowsZattn_windowsri   ri   rj   forward_part1Z  s\    

* 
$,$z"SwinTransformerBlock.forward_part1c                 C  s   |  | | |S )N)r   r   r   )rf   r   ri   ri   rj   forward_part2  s    z"SwinTransformerBlock.forward_part2c                 C  s   d| d| d}dddddd	d
dddddddg}t   | jj|d ||d    | jj|d ||d    | jj|d ||d    | jj|d ||d    | jj	j|d ||d    | jj	j|d ||d    | jj
j|d ||d    | jj
j|d ||d    | jj|d ||d    | jj|d ||d    | jjj|d ||d    | jjj|d ||d    | jjj|d ||d    | jjj|d ||d     W 5 Q R X d S )!Nzmodule.z
.0.blocks..znorm1.weightz
norm1.biasz!attn.relative_position_bias_tablezattn.relative_position_indexzattn.qkv.weightzattn.qkv.biaszattn.proj.weightzattn.proj.biasznorm2.weightz
norm2.biaszmlp.fc1.weightzmlp.fc1.biaszmlp.fc2.weightzmlp.fc2.biasrk   r   r@   r   r&   rO   r   r'   r?   rQ   	   
      r(      )rr   rs   r   rv   rw   rx   r   r   r   r   ru   r   r   linear1linear2)rf   r~   rm   rn   rootZblock_namesri   ri   rj   r{     s>           zSwinTransformerBlock.load_fromc                 C  sj   |}| j r tj| j||dd}n| ||}|| | }| j rX|tj| j|dd }n|| | }|S )NF)use_reentrant)r<   
checkpointr   r   r   )rf   r   r   shortcutri   ri   rj   r     s    zSwinTransformerBlock.forward)r   r   r   r   rW   r   rT   r   r   r{   r   r   ri   ri   rg   rj   r     s   ,56"c                      s<   e Zd ZdZejdfddddd fddZd	d
 Z  ZS )r   z
    Patch merging layer based on: "Liu et al.,
    Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
    <https://arxiv.org/abs/2103.14030>"
    https://github.com/microsoft/Swin-Transformer
    r&   r-   r   r1   )r   rH   r=   r>   c                   sv   t    || _|dkrBtjd| d| dd| _|d| | _n0|dkrrtjd| d| dd| _|d| | _dS )z
        Args:
            dim: number of feature channels.
            norm_layer: normalization layer.
            spatial_dims: number of spatial dims.
        r&   rQ   r   Fr   rO   N)rS   rT   r   rW   r   r|   r}   )rf   r   rH   r=   rg   ri   rj   rT     s    
zPatchMergingV2.__init__c           	        s<     }t|dkr|\}}}}}|d dkpD|d dkpD|d dk}|rrt ddd|d d|d d|d f t fddttdtdtdD d nt|dkr$|\}}}}|d dkp|d dk}|rt ddd|d d|d f t fd	dttdtdD d | 	  | 
   S )
Nr   r   r@   r   c              	     s>   g | ]6\}}} d d |d d|d d|d dd d f qS Nr   ri   )r   r   jr   r   ri   rj   
<listcomp>  s     z*PatchMergingV2.forward.<locals>.<listcomp>r   rO   c                   s4   g | ],\}} d d |d d|d dd d f qS r  ri   )r   r   r  r  ri   rj   r    s     )r   r   r   r   rr   cat	itertoolsproductr   r}   r|   )	rf   r   r   r   r   r   r   r   	pad_inputri   r  rj   r     s&    $(&  *

zPatchMergingV2.forward	r   r   r   r   rW   r   rT   r   r   ri   ri   rg   rj   r     s    c                      s    e Zd ZdZ fddZ  ZS )r   z7The `PatchMerging` module previously defined in v0.9.0.c                   s0  |  }t|dkr t |S t|dkr>td|j d|\}}}}}|d dkpn|d dkpn|d dk}|rt|ddd|d d|d d|d f}|d d dd ddd ddd dd d f }	|d d dd ddd ddd dd d f }
|d d dd ddd ddd dd d f }|d d dd ddd ddd dd d f }|d d dd ddd ddd dd d f }|d d dd ddd ddd dd d f }|d d dd ddd ddd dd d f }|d d dd ddd ddd dd d f }t	|	|
||||||gd}| 
|}| |}|S )	NrO   r   zexpecting 5D x, got r  r   r@   r   r   )r   r   rS   r   rU   r   r   r   rr   r  r}   r|   )rf   r   r   r   r   r   r   r   r  x0x1x2x3x4x5x6x7rg   ri   rj   r     s*    $(,,,,,,,,

zPatchMerging.forward)r   r   r   r   r   r   ri   ri   rg   rj   r     s   )r,   Z	mergingv2c                 C  s  d}t | dkr| \}}}tjd|||df|d}t|d  t|d  |d  t|d  dfD ]}t|d  t|d  |d  t|d  dfD ]^}t|d  t|d  |d  t|d  dfD ]&}||dd|||ddf< |d7 }qqqdnt | dkr| \}}tjd||df|d}t|d  t|d  |d  t|d  dfD ]`}t|d  t|d  |d  t|d  dfD ]&}||dd||ddf< |d7 }qq\t||}	|	d}	|	d|	d }
|
|
dktd|
dktd	}
|
S )
ad  Computing region masks based on: "Liu et al.,
    Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
    <https://arxiv.org/abs/2103.14030>"
    https://github.com/microsoft/Swin-Transformer

     Args:
        dims: dimension values.
        window_size: local window size.
        shift_size: shift size.
        device: device.
    r   r&   r@   )deviceNr   r   g      Yr+   )	r   rr   r   slicer   squeezer   masked_fillr/   )r   rD   r   r  cntr   r   r   Zimg_maskZmask_windowsr   ri   ri   rj   compute_mask  s*    
66666

$r"  c                      sX   e Zd ZdZddddejddfdddddd	d
d	d	ddd
dd fddZdd Z  ZS )r   z
    Basic Swin Transformer layer in one stage based on: "Liu et al.,
    Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
    <https://arxiv.org/abs/2103.14030>"
    https://github.com/microsoft/Swin-Transformer
    rA   Fr+   Nr-   r.   r   r/   r0   r   znn.Module | Noner1   )r   depthr5   rD   r   rE   rF   r   r   rH   rI   r<   r>   c                   s   t    |_tdd |D _tdd |D _|_	_t	 	f
ddt
|D _|_tjr|tjd_dS )a  
        Args:
            dim: number of feature channels.
            depth: number of layers in each stage.
            num_heads: number of attention heads.
            window_size: local window size.
            drop_path: stochastic depth rate.
            mlp_ratio: ratio of mlp hidden dim to embedding dim.
            qkv_bias: add a learnable bias to query, key, value.
            drop: dropout rate.
            attn_drop: attention dropout rate.
            norm_layer: normalization layer.
            downsample: an optional downsampling layer at the end of the layer.
            use_checkpoint: use gradient checkpointing for reduced memory usage.
        c                 s  s   | ]}|d  V  qdS )r   Nri   r   ri   ri   rj   r   b  s     z&BasicLayer.__init__.<locals>.<genexpr>c                 s  s   | ]
}d V  qdS r   ri   r   ri   ri   rj   r   c  s     c                   sR   g | ]J}t j|d  dkr$jnj ttrB| n	dqS )r   r   )r   r5   rD   r   rE   rF   r   r   r   rH   r<   )r   rD   no_shiftr   rX   r   r   
r   r   r   r   rE   rH   r5   rF   rf   r<   ri   rj   r  g  s   z'BasicLayer.__init__.<locals>.<listcomp>)r   rH   r=   N)rS   rT   rD   r   r   r$  r#  r<   rW   
ModuleListr   ry   rI   callabler   )rf   r   r#  r5   rD   r   rE   rF   r   r   rH   rI   r<   rg   r%  rj   rT   A  s    

zBasicLayer.__init__c                 C  s  |  }t|dkr|\}}}}}t|||f| j| j\}}	t|d}tt||d  |d  }
tt||d  |d  }tt||d  |d  }t	|
||g||	|j
}| jD ]}|||}q|||||d}| jd k	r| |}t|d}nt|dkr|\}}}}t||f| j| j\}}	t|d	}tt||d  |d  }tt||d  |d  }t	||g||	|j
}| jD ]}|||}q||||d}| jd k	r| |}t|d
}|S )Nr   zb c d h w -> b d h w cr   r@   r   r   zb d h w c -> b c d h wrO   zb c h w -> b h w czb h w c -> b c h w)r   r   r   rD   r   r   r-   r   ceilr"  r  ry   r   rI   )rf   r   r   r   r   r   r   r   rD   r   r   r   r   r   blkri   ri   rj   r   |  s:    







zBasicLayer.forwardr  ri   ri   rg   rj   r   9  s   ,;c                      sr   e Zd ZdZdddddejdddddfddd	d	d	d	d
dd
d
d
dddddd fddZdddZdddZ  Z	S )r   z
    Swin Transformer based on: "Liu et al.,
    Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
    <https://arxiv.org/abs/2103.14030>"
    https://github.com/microsoft/Swin-Transformer
    rA   Tr+   Fr&   r,   r-   r.   r/   r0   r   r1   )rB   rC   rD   r   r4   r5   rE   rF   r8   r9   rG   rH   
patch_normr<   r=   r>   c                   sX  t    t|| _|| _|| _|| _|| _t| j||| jr@|nd|d| _	t
j|	d| _dd td|t|D }|| _t
 | _t
 | _t
 | _t
 | _| jrt
 | _t
 | _t
 | _t
 | _t|trt|tn|}t| jD ]D}tt |d|  || || | j|t|d| t|d|d   |||	|
|||d	}|dkrl| j!| nF|dkr| j!| n.|dkr| j!| n|d
kr| j!| | jrt"||d|  |d|  d
dddd}|dkr| j!| q|dkr| j!| q|dkr&| j!| q|d
kr| j!| qt |d| jd   | _#dS )a  
        Args:
            in_chans: dimension of input channels.
            embed_dim: number of linear projection output channels.
            window_size: local window size.
            patch_size: patch size.
            depths: number of layers in each stage.
            num_heads: number of attention heads.
            mlp_ratio: ratio of mlp hidden dim to embedding dim.
            qkv_bias: add a learnable bias to query, key, value.
            drop_rate: dropout rate.
            attn_drop_rate: attention dropout rate.
            drop_path_rate: stochastic depth rate.
            norm_layer: normalization layer.
            patch_norm: add normalization after patch embedding.
            use_checkpoint: use gradient checkpointing for reduced memory usage.
            spatial_dims: spatial dimension.
            downsample: module used for downsampling, available options are `"mergingv2"`, `"merging"` and a
                user-specified `nn.Module` following the API defined in :py:class:`monai.networks.nets.PatchMerging`.
                The default is currently `"merging"` (the original version defined in v0.9.0).
            use_v2: using swinunetr_v2, which adds a residual convolution block at the beginning of each swin stage.
        N)r   rB   rC   rH   r=   )pc                 S  s   g | ]}|  qS ri   )item)r   r   ri   ri   rj   r    s     z,SwinTransformer.__init__.<locals>.<listcomp>r   r   r@   )r   r#  r5   rD   r   rE   rF   r   r   rH   rI   r<   r&   r*   TrK   )$rS   rT   r   
num_layersrC   r*  rD   r   r   rt   rW   r   pos_droprr   linspacer   rJ   r&  rl   ro   rp   rq   layers1clayers2clayers3clayers4crX   rY   r   r   r   r   r-   appendr	   num_features)rf   rB   rC   rD   r   r4   r5   rE   rF   r8   r9   rG   rH   r*  r<   r=   rI   rJ   ZdprZdown_sample_modi_layerrn   Zlayercrg   ri   rj   rT     s    +









&





	


zSwinTransformer.__init__c           	      C  s   |r|  }t|dkrJ|\}}}}}t|d}t||g}t|d}n:t|dkr|\}}}}t|d}t||g}t|d}|S )Nr   zn c d h w -> n d h w czn d h w c -> n c d h wrO   zn c h w -> n h w czn h w c -> n c h w)r   r   r   r   
layer_norm)	rf   r   r;   r   r   chr   r   r   ri   ri   rj   proj_out  s    


zSwinTransformer.proj_outc                 C  s  |  |}| |}| ||}| jr8| jd | }| jd | }| ||}| jrn| jd | }| jd | }| ||}| jr| j	d | }| j
d | }	| |	|}
| jr| jd |	 }	| jd |	 }| ||}||||
|gS )Nr   )rt   r.  r9  rJ   r0  r   rl   r1  ro   r2  rp   r3  rq   )rf   r   r;   r  Zx0_outr  Zx1_outr  Zx2_outr  Zx3_outr  Zx4_outri   ri   rj   r   %  s(    

zSwinTransformer.forward)F)T)
r   r   r   r   rW   r   rT   r9  r   r   ri   ri   rg   rj   r     s   2q
c                 C  sj   | dkrdS | dd dkrb| dd dkr>d| dd  }nd| dd  | d	d  }||fS dS dS )
a  
    A filter function used to filter the pretrained weights from [1], then the weights can be loaded into MONAI SwinUNETR Model.
    This function is typically used with `monai.networks.copy_model_state`
    [1] "Valanarasu JM et al., Disruptive Autoencoders: Leveraging Low-level features for 3D Medical Image Pre-training
    <https://arxiv.org/abs/2307.16896>"

    Args:
        key: the key in the source state dict used for the update.
        value: the value in the source state dict used for the update.

    Examples::

        import torch
        from monai.apps import download_url
        from monai.networks.utils import copy_model_state
        from monai.networks.nets.swin_unetr import SwinUNETR, filter_swinunetr

        model = SwinUNETR(img_size=(96, 96, 96), in_channels=1, out_channels=3, feature_size=48)
        resource = (
            "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/ssl_pretrained_weights.pth"
        )
        ssl_weights_path = "./ssl_pretrained_weights.pth"
        download_url(resource, ssl_weights_path)
        ssl_weights = torch.load(ssl_weights_path)["model"]

        dst_dict, loaded, not_loaded = copy_model_state(model, ssl_weights, filter_func=filter_swinunetr)

    )zencoder.mask_tokenzencoder.norm.weightzencoder.norm.biaszout.conv.conv.weightzout.conv.conv.biasNrQ   zencoder.   rt   zswinViT.      ri   )keyvaluenew_keyri   ri   rj   filter_swinunetr<  s    r@  )N)4
__future__r   r  collections.abcr   numpyr   rr   torch.nnrW   torch.nn.functional
functionalr   torch.utils.checkpointutilsr
  r   typing_extensionsr   monai.networks.blocksr   r   r   r   r	   r
   monai.networks.layersr   r   monai.utilsr   r   r   Zmonai.utils.deprecate_utilsr   r   r   __all__Moduler   r   r   r   r   r   r   r   r   r"  r   r   r@  ri   ri   ri   rj   <module>   sV     '! 
m '0
(d  