o
    )i                     @  s  d dl mZ d dlZd dlmZ d dlZd dlZd dlm	Z	 d dl
m	  mZ d dlm  mZ d dlmZ d dlmZ d dlmZmZmZmZ d dlmZmZ d dlmZmZmZ ed	d
d\ZZ g dZ!G dd de	j"Z#dd Z$dd Z%d&ddZ&G dd de	j"Z'G dd de	j"Z(G dd de	j"Z)G dd de)Z*e*e)dZ+dd Z,G d d! d!e	j"Z-G d"d# d#e	j"Z.d$d% Z/dS )'    )annotationsN)Sequence)	LayerNorm)MLPBlock)
PatchEmbedUnetOutBlockUnetrBasicBlockUnetrUpBlock)DropPathtrunc_normal_)ensure_tuple_replook_up_optionoptional_importeinops	rearrange)name)
	SwinUNETRwindow_partitionwindow_reverseWindowAttentionSwinTransformerBlockPatchMergingPatchMergingV2MERGING_MODE
BasicLayerSwinTransformerc                      sj   e Zd ZdZdddddddd	d
d
d
dejdddddfd4 fd,d-Zd.d/ Zej	j
d0d1 Zd2d3 Z  ZS )5r   z
    Swin UNETR based on: "Hatamizadeh et al.,
    Swin UNETR: Swin Transformers for Semantic Segmentation of Brain Tumors in MRI Images
    <https://arxiv.org/abs/2201.01266>"
       )r   r   r   r   )               T      @r    instance        Fr   mergingin_channelsintout_channels
patch_sizedepthsSequence[int]	num_headswindow_sizeSequence[int] | intqkv_biasbool	mlp_ratiofloatfeature_size	norm_nametuple | str	drop_rateattn_drop_ratedropout_path_rate	normalize
norm_layertype[LayerNorm]
patch_normuse_checkpointspatial_dims
downsamplestr | nn.Moduleuse_v2returnNonec              	     s  t    |dvrtd|| _t| j|}t||}d|  kr(dks-td tdd|  kr:dks?td tdd|  krLdksQtd td|	d dkr[td	|| _td%i d
|d|	d|d|d|d|d|d|d|d|d|d|d|d|d|dt|trt	|t
n|d|| _t|||	dd|
dd| _t||	|	dd|
dd| _t|d|	 d|	 dd|
dd| _t|d|	 d|	 dd|
dd| _t|d |	 d |	 dd|
dd| _t|d |	 d!|	 dd|
dd"| _t||	d! |	d dd|
dd"| _t||	d |	d dd|
dd"| _t||	d |	dd|
dd"| _t||	|	dd|
dd"| _t||	|d#| _d$S d|| _t|||	dd|
dd| _t||	|	dd|
dd| _t|d|	 d|	 dd|
dd| _t|d|	 d|	 dd|
dd| _t|d |	 d |	 dd|
dd| _t|d |	 d!|	 dd|
dd"| _t||	d! |	d dd|
dd"| _t||	d |	d dd|
dd"| _t||	d |	dd|
dd"| _t||	|	dd|
dd"| _t||	|d#| _d$S )&aH  
        Args:
            in_channels: dimension of input channels.
            out_channels: dimension of output channels.
            patch_size: size of the patch token.
            feature_size: dimension of network feature size.
            depths: number of layers in each stage.
            num_heads: number of attention heads.
            window_size: local window size.
            qkv_bias: add a learnable bias to query, key, value.
            mlp_ratio: ratio of mlp hidden dim to embedding dim.
            norm_name: feature normalization type and arguments.
            drop_rate: dropout rate.
            attn_drop_rate: attention dropout rate.
            dropout_path_rate: drop path rate.
            normalize: normalize output intermediate features in each stage.
            norm_layer: normalization layer.
            patch_norm: whether to apply normalization to the patch embedding. Default is False.
            use_checkpoint: use gradient checkpointing for reduced memory usage.
            spatial_dims: number of spatial dims.
            downsample: module used for downsampling, available options are `"mergingv2"`, `"merging"` and a
                user-specified `nn.Module` following the API defined in :py:class:`monai.networks.nets.PatchMerging`.
                The default is currently `"merging"` (the original version defined in v0.9.0).
            use_v2: using swinunetr_v2, which adds a residual convolution block at the beggining of each swin stage.

        Examples::

            # for 3D single channel input with size (96,96,96), 4-channel output and feature size of 48.
            >>> net = SwinUNETR(in_channels=1, out_channels=4, feature_size=48)

            # for 3D 4-channel input with size (128,128,128), 3-channel output and (2,4,2,2) layers in each stage.
            >>> net = SwinUNETR(in_channels=4, out_channels=3, depths=(2,4,2,2))

            # for 2D single channel input with size (96,96), 2-channel output and gradient checkpointing.
            >>> net = SwinUNETR(in_channels=3, out_channels=2, use_checkpoint=True, spatial_dims=2)

        )r   r   z#spatial dimension should be 2 or 3.r      z'dropout rate should be between 0 and 1.z1attention dropout rate should be between 0 and 1.z)drop path rate should be between 0 and 1.r   z'feature_size should be divisible by 12.in_chans	embed_dimr-   r)   r*   r,   r1   r/   r6   r7   drop_path_rater:   r<   r=   r>   r?   rA   r   Tr>   r&   r(   kernel_sizestrider4   	res_blockr            )r>   r&   r(   rI   upsample_kernel_sizer4   rK   )r>   r&   r(   N )super__init__
ValueErrorr)   r   r9   r   
isinstancestrr   r   swinViTr   encoder1encoder2encoder3encoder4	encoder10r	   decoder5decoder4decoder3decoder2decoder1r   out)selfr&   r(   r)   r*   r,   r-   r/   r1   r3   r4   r6   r7   r8   r9   r:   r<   r=   r>   r?   rA   Zpatch_sizes	__class__rP   `/home/dell461/cl/sdc2/last_ska_mid/HISourceFinder-master-l/src/monai/networks/nets/swin_unetr.pyrR   4   s  
=
	







	








	

zSwinUNETR.__init__c           
      C  sN  | j jd }| j jd }| j jd }| j jd }|d }t  | j jjj	
|d  | j jjj
|d  |j D ]\}}|j||dd q<|jd url|j}	|	jj	
|d  |	jj	
|d  |	jj
|d	  |j D ]\}}|j||d
d qq|jd ur|j}	|	jj	
|d  |	jj	
|d  |	jj
|d  |j D ]\}}|j||dd q|jd ur|j}	|	jj	
|d  |	jj	
|d  |	jj
|d  |j D ]\}}|j||dd q|jd ur|j}	|	jj	
|d  |	jj	
|d  |	jj
|d  W d    d S W d    d S 1 s w   Y  d S )Nr   
state_dictzmodule.patch_embed.proj.weightzmodule.patch_embed.proj.biaslayers1)n_blocklayerz,module.layers1.0.downsample.reduction.weightz'module.layers1.0.downsample.norm.weightz%module.layers1.0.downsample.norm.biaslayers2z,module.layers2.0.downsample.reduction.weightz'module.layers2.0.downsample.norm.weightz%module.layers2.0.downsample.norm.biaslayers3z,module.layers3.0.downsample.reduction.weightz'module.layers3.0.downsample.norm.weightz%module.layers3.0.downsample.norm.biaslayers4z,module.layers4.0.downsample.reduction.weightz'module.layers4.0.downsample.norm.weightz%module.layers4.0.downsample.norm.bias)rV   rg   rj   rk   rl   torchno_gradpatch_embedprojweightcopy_biasblocksnamed_children	load_fromr?   	reductionnorm)
rb   weightsZ	layers1_0Z	layers2_0Z	layers3_0Z	layers4_0ZwstateZbnameblockdrP   rP   re   rv     sN   



!$zSwinUNETR.load_fromc                 C  s`   t |}|t | jd dk}| r.t |d d  }td| d| d| j dd S )N   r   r   zspatial dimensions z  of input image (spatial shape: z) must be divisible by z**5.)nparraypowerr)   anywheretolistrS   )rb   spatial_shapeimg_size	remainderZ
wrong_dimsrP   rP   re   _check_input_size0  s   
zSwinUNETR._check_input_sizec                 C  s   t j st j s| |jdd   | || j}| |}| 	|d }| 
|d }| |d }| |d }| ||d }| ||}	| |	|}
| |
|}| ||}| |}|S )Nr   r   rD   rL   r   )rm   jitis_scripting
is_tracingr   shaperV   r9   rW   rX   rY   rZ   r[   r\   r]   r^   r_   r`   ra   )rb   x_inZhidden_states_outZenc0Zenc1Zenc2Zenc3Zdec4dec3dec2dec1dec0ra   logitsrP   rP   re   forward;  s   

zSwinUNETR.forward)*r&   r'   r(   r'   r)   r'   r*   r+   r,   r+   r-   r.   r/   r0   r1   r2   r3   r'   r4   r5   r6   r2   r7   r2   r8   r2   r9   r0   r:   r;   r<   r0   r=   r0   r>   r'   r?   r@   rA   r0   rB   rC   )__name__
__module____qualname____doc__nnr   rR   rv   rm   r   unusedr   r   __classcell__rP   rP   rc   re   r   -   s4    
 O.

r   c           	   
   C  s  |   }t|dkrN|\}}}}}| |||d  |d ||d  |d ||d  |d |} | dddddddd d	|d |d  |d  |}|S | j\}}}}| |||d  |d ||d  |d |} | dddddd d	|d |d  |}|S )
a)  window partition operation based on: "Liu et al.,
    Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
    <https://arxiv.org/abs/2103.14030>"
    https://github.com/microsoft/Swin-Transformer

     Args:
        x: input tensor.
        window_size: local window size.
    r|   r   rD   r   r   rL   r   r!   )sizelenviewpermute
contiguousr   )	xr-   x_shapebr{   hwcwindowsrP   rP   re   r   M  s(   



8,.r   c              
   C  s   t |dkrA|\}}}}| |||d  ||d  ||d  |d |d |d d}|dddddddd	 ||||d}|S t |dkrt|\}}}| |||d  ||d  |d |d d}|dddddd |||d}|S )
aO  window reverse operation based on: "Liu et al.,
    Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
    <https://arxiv.org/abs/2103.14030>"
    https://github.com/microsoft/Swin-Transformer

     Args:
        windows: windows tensor.
        window_size: local window size.
        dims: dimension values.
    rL   r   rD   r   r   r|   r   r   r!   )r   r   r   r   )r   r-   dimsr   r{   r   r   r   rP   rP   re   r   o  s&   


*

,$r   c                 C  sv   t |}|durt |}tt| D ]}| | || kr*| | ||< |dur*d||< q|du r3t|S t|t|fS )aQ  Computing window size based on: "Liu et al.,
    Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
    <https://arxiv.org/abs/2103.14030>"
    https://github.com/microsoft/Swin-Transformer

     Args:
        x_size: input size.
        window_size: local window size.
        shift_size: window shifting size.
    Nr   )listranger   tuple)x_sizer-   
shift_sizeZuse_window_sizeZuse_shift_sizeirP   rP   re   get_window_size  s   r   c                      s2   e Zd ZdZ			dd fddZdd Z  ZS )r   a  
    Window based multi-head self attention module with relative position bias based on: "Liu et al.,
    Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
    <https://arxiv.org/abs/2103.14030>"
    https://github.com/microsoft/Swin-Transformer
    Fr$   dimr'   r,   r-   r+   r/   r0   	attn_dropr2   	proj_droprB   rC   c                   s  t    || _|| _|| _|| }|d | _tjj}t	| jdkrt
td| jd  d d| jd  d  d| jd  d  || _t| jd }	t| jd }
t| jd }|durottj|	|
|dd}n
tt|	|
|}t|d}|dddddf |dddddf  }|ddd }|dddddf  | jd d 7  < |dddddf  | jd d 7  < |dddddf  | jd d 7  < |dddddf  d| jd  d d| jd  d  9  < |dddddf  d| jd  d 9  < nt	| jdkrt
td|d  d d|d  d  || _t| jd }
t| jd }|durQttj|
|dd}n	tt|
|}t|d}|dddddf |dddddf  }|ddd }|dddddf  | jd d 7  < |dddddf  | jd d 7  < |dddddf  d| jd  d 9  < |d	}| d
| t
j||d |d| _t
|| _t
||| _t
|| _t| jdd t
jd	d| _dS )aA  
        Args:
            dim: number of feature channels.
            num_heads: number of attention heads.
            window_size: local window size.
            qkv_bias: add a learnable bias to query, key, value.
            attn_drop: attention dropout rate.
            proj_drop: dropout rate of output.
        g      r   r   r   rD   Nij)indexingr   relative_position_indexrs   g{Gz?)std)r   )rQ   rR   r   r-   r,   scalerm   meshgrid__kwdefaults__r   r   	Parameterzerosrelative_position_bias_tablearangestackflattenr   r   sumregister_bufferLinearqkvDropoutr   rp   r   r   Softmaxsoftmax)rb   r   r,   r-   r/   r   r   head_dimZ	mesh_argsZcoords_dcoords_hcoords_wcoordscoords_flattenrelative_coordsr   rc   rP   re   rR     sf   

4,(((>.&
,((,
zWindowAttention.__init__c                 C  sf  |j \}}}| |||d| j|| j ddddd}|d |d |d }}}	|| j }||dd }
| j| j	 d |d |f d ||d}|ddd
 }|
|d }
|d ur|j d }|
|| || j|||dd }
|
d| j||}
| |
}
n| |
}
| |
|	j}
|
|	 dd|||}| |}| |}|S )Nr   r   r   rD   rL   r   )r   r   reshaper,   r   r   	transposer   r   cloner   	unsqueezer   r   r   todtyperp   r   )rb   r   maskr   nr   r   qkvattnrelative_position_biasnwrP   rP   re   r     s.   .


(


zWindowAttention.forward)Fr$   r$   )r   r'   r,   r'   r-   r+   r/   r0   r   r2   r   r2   rB   rC   )r   r   r   r   rR   r   r   rP   rP   rc   re   r     s    Kr   c                      sV   e Zd ZdZddddddejdfd% fddZdd Zdd  Zd!d" Z	d#d$ Z
  ZS )&r   z
    Swin Transformer block based on: "Liu et al.,
    Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
    <https://arxiv.org/abs/2103.14030>"
    https://github.com/microsoft/Swin-Transformer
    r"   Tr$   GELUFr   r'   r,   r-   r+   r   r1   r2   r/   r0   dropr   	drop_path	act_layerrU   r:   r;   r=   rB   rC   c                   s   t    || _|| _|| _|| _|| _|| _||| _t	|| j||||d| _
|	dkr0t|	nt | _||| _t|| }t|||
|dd| _dS )as  
        Args:
            dim: number of feature channels.
            num_heads: number of attention heads.
            window_size: local window size.
            shift_size: window shift size.
            mlp_ratio: ratio of mlp hidden dim to embedding dim.
            qkv_bias: add a learnable bias to query, key, value.
            drop: dropout rate.
            attn_drop: attention dropout rate.
            drop_path: stochastic depth rate.
            act_layer: activation layer.
            norm_layer: normalization layer.
            use_checkpoint: use gradient checkpointing for reduced memory usage.
        )r-   r,   r/   r   r   r$   swin)hidden_sizemlp_dimactdropout_ratedropout_modeN)rQ   rR   r   r,   r-   r   r1   r=   norm1r   r   r
   r   Identityr   norm2r'   Mlpmlp)rb   r   r,   r-   r   r1   r/   r   r   r   r   r:   r=   Zmlp_hidden_dimrc   rP   re   rR     s(   

	
zSwinTransformerBlock.__init__c                 C  sP  |  }| |}t|dkrq|j\}}}}}t|||f| j| j\}	}
d } }}|	d ||	d   |	d  }|	d ||	d   |	d  }|	d ||	d   |	d  }t|dd||||||f}|j\}}}}}||||g}nK|j\}}}}t||f| j| j\}	}
d }}|	d ||	d   |	d  }|	d ||	d   |	d  }t|dd||||f}|j\}}}}|||g}t	dd |
D rt|dkrt
j||
d  |
d  |
d  fdd}nt|d	krt
j||
d  |
d  fd
d}|}n|}d }t||	}| j||d}|jdg|	|f R  }t||	|}t	dd |
D rTt|dkr>t
j||
d |
d |
d fdd}nt|d	krSt
j||
d |
d fd
d}n|}t|dkr|dksl|dksl|dkr|d d d |d |d |d d f  }|S t|d	kr|dks|dkr|d d d |d |d d f  }|S )Nr|   r   rD   r   c                 s      | ]}|d kV  qdS r   NrP   .0r   rP   rP   re   	<genexpr>l      z5SwinTransformerBlock.forward_part1.<locals>.<genexpr>)rD   r   r   )shiftsr   rL   )rD   r   )r   r   c                 s  r   r   rP   r   rP   rP   re   r   y  r   )r   r   r   r   r   r-   r   Fpadr   rm   rollr   r   r   r   r   )rb   r   mask_matrixr   r   r{   r   r   r   r-   r   pad_lpad_tZpad_d0Zpad_d1pad_bpad_r_dphpwpr   Z	shifted_x	attn_maskZ	x_windowsZattn_windowsrP   rP   re   forward_part1T  s^   

* 
$*$z"SwinTransformerBlock.forward_part1c                 C  s   |  | | |S N)r   r   r   )rb   r   rP   rP   re   forward_part2  s   z"SwinTransformerBlock.forward_part2c                 C  s  d| d| d}g d}t   | jj|d ||d    | jj|d ||d    | jj|d ||d    | jj|d ||d	    | jj	j|d ||d
    | jj	j|d ||d    | jj
j|d ||d    | jj
j|d ||d    | jj|d ||d    | jj|d ||d    | jjj|d ||d    | jjj|d ||d    | jjj|d ||d    | jjj|d ||d    W d    d S 1 sw   Y  d S )Nzmodule.z
.0.blocks..)znorm1.weightz
norm1.biasz!attn.relative_position_bias_tablezattn.relative_position_indexzattn.qkv.weightzattn.qkv.biaszattn.proj.weightzattn.proj.biasznorm2.weightz
norm2.biaszmlp.fc1.weightzmlp.fc1.biaszmlp.fc2.weightzmlp.fc2.biasrf   r   rD   r   r   rL   r|   r   r!   rN   	   
      r      )rm   rn   r   rq   rr   rs   r   r   r   r   rp   r   r   linear1linear2)rb   ry   rh   ri   rootZblock_namesrP   rP   re   rv     s$   
       ""zSwinTransformerBlock.load_fromc                 C  sl   |}| j rtj| j||dd}n| ||}|| | }| j r-|tj| j|dd }|S || | }|S )NF)use_reentrant)r=   
checkpointr   r   r   )rb   r   r   shortcutrP   rP   re   r     s   zSwinTransformerBlock.forward)r   r'   r,   r'   r-   r+   r   r+   r1   r2   r/   r0   r   r2   r   r2   r   r2   r   rU   r:   r;   r=   r0   rB   rC   )r   r   r   r   r   r   rR   r   r   rv   r   r   rP   rP   rc   re   r     s    56"r   c                      s2   e Zd ZdZejdfd fd
dZdd Z  ZS )r   z
    Patch merging layer based on: "Liu et al.,
    Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
    <https://arxiv.org/abs/2103.14030>"
    https://github.com/microsoft/Swin-Transformer
    r   r   r'   r:   r;   r>   rB   rC   c                   s|   t    || _|dkr"tjd| d| dd| _|d| | _dS |dkr<tjd| d| dd| _|d| | _dS dS )z
        Args:
            dim: number of feature channels.
            norm_layer: normalization layer.
            spatial_dims: number of spatial dims.
        r   rN   r   Fr   rL   N)rQ   rR   r   r   r   rw   rx   )rb   r   r:   r>   rc   rP   re   rR     s   
zPatchMergingV2.__init__c           	        s:     }t|dkrR|\}}}}}|d dkp"|d dkp"|d dk}|r9t ddd|d d|d d|d f t fddttdtdtdD d n?t|dkr|\}}}}|d dkpi|d dk}|r|t ddd|d d|d f t fd	dttdtdD d | 	  | 
   S )
Nr|   r   rD   r   c              	     s>   g | ]\}}} d d |d d|d d|d dd d f qS Nr   rP   )r   r   jr   r   rP   re   
<listcomp>  s   > z*PatchMergingV2.forward.<locals>.<listcomp>r   rL   c                   s4   g | ]\}} d d |d d|d dd d f qS r  rP   )r   r   r  r  rP   re   r    s   4 )r   r   r   r   rm   cat	itertoolsproductr   rx   rw   )	rb   r   r   r   r{   r   r   r   	pad_inputrP   r  re   r     s$   $(( *

zPatchMergingV2.forward)r   r'   r:   r;   r>   r'   rB   rC   	r   r   r   r   r   r   rR   r   r   rP   rP   rc   re   r     s    r   c                      s    e Zd ZdZ fddZ  ZS )r   z7The `PatchMerging` module previously defined in v0.9.0.c                   s0  |  }t|dkrt |S t|dkrtd|j d|\}}}}}|d dkp7|d dkp7|d dk}|rNt|ddd|d d|d d|d f}|d d dd ddd ddd dd d f }	|d d dd ddd ddd dd d f }
|d d dd ddd ddd dd d f }|d d dd ddd ddd dd d f }|d d dd ddd ddd dd d f }|d d dd ddd ddd dd d f }|d d dd ddd ddd dd d f }|d d dd ddd ddd dd d f }t	|	|
||||||gd}| 
|}| |}|S )	NrL   r|   zexpecting 5D x, got r  r   rD   r   r   )r   r   rQ   r   rS   r   r   r   rm   r  rx   rw   )rb   r   r   r   r{   r   r   r   r  x0x1x2x3x4x5x6x7rc   rP   re   r     s*   $(,,,,,,,,

zPatchMerging.forward)r   r   r   r   r   r   rP   rP   rc   re   r     s    r   )r%   Z	mergingv2c                 C  s  d}t | dkr| \}}}tjd|||df|d}t|d  t|d  |d  t|d  dfD ]K}t|d  t|d  |d  t|d  dfD ]/}t|d  t|d  |d  t|d  dfD ]}||dd|||ddf< |d7 }qhqMq2n]t | dkr| \}}tjd||df|d}t|d  t|d  |d  t|d  dfD ].}t|d  t|d  |d  t|d  dfD ]}||dd||ddf< |d7 }qqt||}	|	d}	|	d|	d }
|
|
dktd|
dktd	}
|
S )
ad  Computing region masks based on: "Liu et al.,
    Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
    <https://arxiv.org/abs/2103.14030>"
    https://github.com/microsoft/Swin-Transformer

     Args:
        dims: dimension values.
        window_size: local window size.
        shift_size: shift size.
        device: device.
    r   r   rD   )deviceNr   r   g      Yr$   )	r   rm   r   slicer   squeezer   masked_fillr2   )r   r-   r   r  cntr{   r   r   Zimg_maskZmask_windowsr   rP   rP   re   compute_mask  s2   
666
66


$r"  c                      s<   e Zd ZdZddddejddfd fddZdd Z  ZS ) r   z
    Basic Swin Transformer layer in one stage based on: "Liu et al.,
    Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
    <https://arxiv.org/abs/2103.14030>"
    https://github.com/microsoft/Swin-Transformer
    r"   Fr$   Nr   r'   depthr,   r-   r+   r   r   r1   r2   r/   r0   r   r   r:   r;   r?   nn.Module | Noner=   rB   rC   c                   s   t    |_tdd |D _tdd |D _|_	_t	 	f
ddt
|D _|_tjrO|tjd_dS dS )a  
        Args:
            dim: number of feature channels.
            depth: number of layers in each stage.
            num_heads: number of attention heads.
            window_size: local window size.
            drop_path: stochastic depth rate.
            mlp_ratio: ratio of mlp hidden dim to embedding dim.
            qkv_bias: add a learnable bias to query, key, value.
            drop: dropout rate.
            attn_drop: attention dropout rate.
            norm_layer: normalization layer.
            downsample: an optional downsampling layer at the end of the layer.
            use_checkpoint: use gradient checkpointing for reduced memory usage.
        c                 s  s    | ]}|d  V  qdS )r   NrP   r   rP   rP   re   r   \  r   z&BasicLayer.__init__.<locals>.<genexpr>c                 s  s    | ]}d V  qdS r   rP   r   rP   rP   re   r   ]  s    c                   sR   g | ]%}t j|d  dkrjnj ttr!| n	dqS )r   r   )r   r,   r-   r   r1   r/   r   r   r   r:   r=   )r   r-   no_shiftr   rT   r   r   
r   r   r   r   r1   r:   r,   r/   rb   r=   rP   re   r  a  s     z'BasicLayer.__init__.<locals>.<listcomp>)r   r:   r>   N)rQ   rR   r-   r   r   r%  r#  r=   r   
ModuleListr   rt   r?   callabler   )rb   r   r#  r,   r-   r   r1   r/   r   r   r:   r?   r=   rc   r&  re   rR   ;  s   

zBasicLayer.__init__c                 C  s  |  }t|dkr|\}}}}}t|||f| j| j\}}	t|d}tt||d  |d  }
tt||d  |d  }tt||d  |d  }t	|
||g||	|j
}| jD ]}|||}q^|||||d}| jd ury| |}t|d}|S t|dkr|\}}}}t||f| j| j\}}	t|d	}tt||d  |d  }tt||d  |d  }t	||g||	|j
}| jD ]}|||}q||||d}| jd ur| |}t|d
}|S )Nr|   zb c d h w -> b d h w cr   rD   r   r   zb d h w c -> b c d h wrL   zb c h w -> b h w czb h w c -> b c h w)r   r   r   r-   r   r   r'   r}   ceilr"  r  rt   r   r?   )rb   r   r   r   r   r{   r   r   r-   r   r   r   r   r   blkrP   rP   re   r   v  s<   









zBasicLayer.forward)r   r'   r#  r'   r,   r'   r-   r+   r   r   r1   r2   r/   r0   r   r2   r   r2   r:   r;   r?   r$  r=   r0   rB   rC   r  rP   rP   rc   re   r   3  s    ;r   c                      sP   e Zd ZdZdddddejdddddfd$ fddZd%d d!Zd&d"d#Z  Z	S )'r   z
    Swin Transformer based on: "Liu et al.,
    Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
    <https://arxiv.org/abs/2103.14030>"
    https://github.com/microsoft/Swin-Transformer
    r"   Tr$   Fr   r%   rE   r'   rF   r-   r+   r)   r*   r,   r1   r2   r/   r0   r6   r7   rG   r:   r;   r<   r=   r>   rB   rC   c                   sP  t    t|| _|| _|| _|| _|| _t| j||| jr |nd|d| _	t
j|	d| _dd td|t|D }|| _t
 | _t
 | _t
 | _t
 | _| jrit
 | _t
 | _t
 | _t
 | _t|trst|tn|}t| jD ]}tt |d|  || || | j|t|d| t|d|d   |||	|
|||d	}|dkr| j!| n |dkr| j!| n|dkr| j!| n
|d
kr| j!| | jrt"||d|  |d|  d
dddd}|dkr| j!| qz|dkr| j!| qz|dkr| j!| qz|d
kr| j!| qzt |d| jd   | _#dS )a  
        Args:
            in_chans: dimension of input channels.
            embed_dim: number of linear projection output channels.
            window_size: local window size.
            patch_size: patch size.
            depths: number of layers in each stage.
            num_heads: number of attention heads.
            mlp_ratio: ratio of mlp hidden dim to embedding dim.
            qkv_bias: add a learnable bias to query, key, value.
            drop_rate: dropout rate.
            attn_drop_rate: attention dropout rate.
            drop_path_rate: stochastic depth rate.
            norm_layer: normalization layer.
            patch_norm: add normalization after patch embedding.
            use_checkpoint: use gradient checkpointing for reduced memory usage.
            spatial_dims: spatial dimension.
            downsample: module used for downsampling, available options are `"mergingv2"`, `"merging"` and a
                user-specified `nn.Module` following the API defined in :py:class:`monai.networks.nets.PatchMerging`.
                The default is currently `"merging"` (the original version defined in v0.9.0).
            use_v2: using swinunetr_v2, which adds a residual convolution block at the beginning of each swin stage.
        N)r)   rE   rF   r:   r>   )pc                 S  s   g | ]}|  qS rP   )item)r   r   rP   rP   re   r    s    z,SwinTransformer.__init__.<locals>.<listcomp>r   r   rD   )r   r#  r,   r-   r   r1   r/   r   r   r:   r?   r=   r   r#   TrH   )$rQ   rR   r   
num_layersrF   r<   r-   r)   r   ro   r   r   pos_droprm   linspacer   rA   r'  rg   rj   rk   rl   layers1clayers2clayers3clayers4crT   rU   r   r   r   r   r'   appendr   num_features)rb   rE   rF   r-   r)   r*   r,   r1   r/   r6   r7   rG   r:   r<   r=   r>   r?   rA   ZdprZdown_sample_modi_layerri   Zlayercrc   rP   re   rR     s   
+








&

	


zSwinTransformer.__init__c                 C  sz   |r;|j }t|d }t|dkr$t|d}t||g}t|d}|S t|dkr;t|d}t||g}t|d}|S )NrD   r|   zn c d h w -> n d h w czn d h w c -> n c d h wrL   zn c h w -> n h w czn h w c -> n c h w)r   r'   r   r   r   
layer_norm)rb   r   r9   r   chrP   rP   re   proj_out  s   



zSwinTransformer.proj_outc                 C  s  |  |}| |}| ||}| jr| jd | }| jd | }| ||}| jr7| jd | }| jd | }| ||}| jrR| j	d | }| j
d | }	| |	|}
| jrm| jd |	 }	| jd |	 }| ||}||||
|gS )Nr   )ro   r.  r9  rA   r0  r   rg   r1  rj   r2  rk   r3  rl   )rb   r   r9   r  Zx0_outr  Zx1_outr  Zx2_outr  Zx3_outr  Zx4_outrP   rP   re   r     s(   

zSwinTransformer.forward) rE   r'   rF   r'   r-   r+   r)   r+   r*   r+   r,   r+   r1   r2   r/   r0   r6   r2   r7   r2   rG   r2   r:   r;   r<   r0   r=   r0   r>   r'   rB   rC   )F)T)
r   r   r   r   r   r   rR   r9  r   r   rP   rP   rc   re   r     s     
qr   c                 C  sl   | dv rdS | dd dkr4| dd dkr"d| dd  }||fS d| dd  | d	d  }||fS dS )
a  
    A filter function used to filter the pretrained weights from [1], then the weights can be loaded into MONAI SwinUNETR Model.
    This function is typically used with `monai.networks.copy_model_state`
    [1] "Valanarasu JM et al., Disruptive Autoencoders: Leveraging Low-level features for 3D Medical Image Pre-training
    <https://arxiv.org/abs/2307.16896>"

    Args:
        key: the key in the source state dict used for the update.
        value: the value in the source state dict used for the update.

    Examples::

        import torch
        from monai.apps import download_url
        from monai.networks.utils import copy_model_state
        from monai.networks.nets.swin_unetr import SwinUNETR, filter_swinunetr

        model = SwinUNETR(in_channels=1, out_channels=3, feature_size=48)
        resource = (
            "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/ssl_pretrained_weights.pth"
        )
        ssl_weights_path = "./ssl_pretrained_weights.pth"
        download_url(resource, ssl_weights_path)
        ssl_weights = torch.load(ssl_weights_path, weights_only=True)["model"]

        dst_dict, loaded, not_loaded = copy_model_state(model, ssl_weights, filter_func=filter_swinunetr)

    )zencoder.mask_tokenzencoder.norm.weightzencoder.norm.biaszout.conv.conv.weightzout.conv.conv.biasNrN   zencoder.   ro   zswinViT.      rP   )keyvaluenew_keyrP   rP   re   filter_swinunetr6  s   r@  r   )0
__future__r   r  collections.abcr   numpyr}   rm   torch.nnr   torch.nn.functional
functionalr   torch.utils.checkpointutilsr
  r   monai.networks.blocksr   r   r   r   r   r	   monai.networks.layersr
   r   monai.utilsr   r   r   r   r   __all__Moduler   r   r   r   r   r   r   r   r   r"  r   r   r@  rP   rP   rP   re   <module>   s@     ""
 m '0
(d  