o
    ,i3                     @  s   d dl mZ d dlmZ d dlZd dlmZ d dlmZm	Z	m
Z
mZ g dZG dd dejZ				
		ddddZeZdd  Z Z ZZdd  Z Z ZZdd  Z Z ZZdd  Z Z ZZ dS )    )annotations)SequenceN)MedNeXtBlockMedNeXtDownBlockMedNeXtOutBlockMedNeXtUpBlock)MedNeXtMedNeXtSmallMedNeXtBaseMedNeXtMediumMedNeXtLargeMedNextMedNextSMedNeXtSMedNextSmallMedNextBMedNeXtBMedNextBaseMedNextMMedNeXtMMedNextMediumMedNextLMedNeXtLMedNextLargec                      sL   e Zd ZdZ																d&d' fddZd(d$d%Z  ZS ))r   ak  
    MedNeXt model class from paper: https://arxiv.org/pdf/2303.09975

    Args:
        spatial_dims: spatial dimension of the input data. Defaults to 3.
        init_filters: number of output channels for initial convolution layer. Defaults to 32.
        in_channels: number of input channels for the network. Defaults to 1.
        out_channels: number of output channels for the network. Defaults to 2.
        encoder_expansion_ratio: expansion ratio for encoder blocks. Defaults to 2.
        decoder_expansion_ratio: expansion ratio for decoder blocks. Defaults to 2.
        bottleneck_expansion_ratio: expansion ratio for bottleneck blocks. Defaults to 2.
        kernel_size: kernel size for convolutions. Defaults to 7.
        deep_supervision: whether to use deep supervision. Defaults to False.
        use_residual_connection: whether to use residual connections in standard, down and up blocks. Defaults to False.
        blocks_down: number of blocks in each encoder stage. Defaults to [2, 2, 2, 2].
        blocks_bottleneck: number of blocks in bottleneck stage. Defaults to 2.
        blocks_up: number of blocks in each decoder stage. Defaults to [2, 2, 2, 2].
        norm_type: type of normalization layer. Defaults to 'group'.
        global_resp_norm: whether to use Global Response Normalization. Defaults to False. Refer: https://arxiv.org/abs/2301.00808
                    Fr   r   r   r   groupspatial_dimsintinit_filtersin_channelsout_channelsencoder_expansion_ratioSequence[int] | intdecoder_expansion_ratiobottleneck_expansion_ratiokernel_sizedeep_supervisionbooluse_residual_connectionblocks_downSequence[int]blocks_bottleneck	blocks_up	norm_typestrglobal_resp_normc                   sH  t    |	| _|dv sJ d| d| ttr%gt  ttr1gt dkr8tjntj}||	dd| _	g }g }t
 D ]6\}|tj	
fddt|D   |t	d	  	d	d    
d
 qKt|| _t|| _tj 	
fddt|D  | _g }g }t
D ]@\}|t	d	t   	d	t d    
d |tj	
f	ddt|D   qt|| _t|| _t	d| _|	r"	fddtdtd D }|  t|| _dS dS )aG  
        Initialize the MedNeXt model.

        This method sets up the architecture of the model, including:
        - Stem convolution
        - Encoder stages and downsampling blocks
        - Bottleneck blocks
        - Decoder stages and upsampling blocks
        - Output blocks for deep supervision (if enabled)
        )r   r   z"`spatial_dims` can only be 2 or 3.d2dr   )r*   c                   s8   g | ]}t d   d     dqS r   r$   r%   expansion_ratior*   r-   r2   dimr4   )r   .0_)enc_kernel_sizer&   r4   ir#   r2   spatial_dims_strr-    ]/home/dell461/cl/sdc2/last_ska_mid/HISourceFinder-master-l/src/monai/networks/nets/mednext.py
<listcomp>x   s    

z$MedNeXt.__init__.<locals>.<listcomp>r   )r$   r%   r9   r*   r-   r2   r:   c                   s<   g | ]}t d t   d t   dqS r7   r   lenr;   )r.   r)   dec_kernel_sizer4   r#   r2   r@   r-   rA   rB   rC      s    r8   c                   sP   g | ]$}t d t  d   d t  d    dqS )r   r   r8   rD   r;   )	r1   rF   r(   r4   r?   r#   r2   r@   r-   rA   rB   rC      s    r$   	n_classesr:   c                   s"   g | ]}t  d |  dqS )r   rG   )r   )r<   r?   )r#   r%   r@   rA   rB   rC      s    N)super__init__do_ds
isinstancer"   rE   nnConv2dConv3dstem	enumerateappend
Sequentialranger   
ModuleList
enc_stagesdown_blocks
bottleneckr   	up_blocks
dec_stagesr   out_0reverse
out_blocks)selfr!   r#   r$   r%   r&   r(   r)   r*   r+   r-   r.   r0   r1   r2   r4   convrV   rW   
num_blocksrY   rZ   r]   	__class__)r.   r1   r)   rF   r(   r>   r&   r4   r?   r#   r2   r%   r@   r-   rB   rJ   E   s   




zMedNeXt.__init__xtorch.Tensorreturn%torch.Tensor | Sequence[torch.Tensor]c           	      C  s   |  |}g }t| j| jD ]\}}||}|| ||}q| |}| jr*g }tt| j| j	D ]+\}\}}| jrM|t
| jk rM|| j| | ||}|||d    }||}q3| |}| jru| jru|g|ddd R S |S )a  
        Forward pass of the MedNeXt model.

        This method performs the forward pass through the model, including:
        - Stem convolution
        - Encoder stages and downsampling
        - Bottleneck blocks
        - Decoder stages and upsampling with skip connections
        - Output blocks for deep supervision (if enabled)

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor or Sequence[torch.Tensor]: Output tensor(s).
        r   N)rP   ziprV   rW   rR   rX   rK   rQ   rY   rZ   rE   r]   r[   training)	r^   rc   Zenc_outputsZ	enc_stage
down_blockZ
ds_outputsr?   up_blockZ	dec_stagerA   rA   rB   forward   s&   





zMedNeXt.forward)r   r   r   r   r   r   r   r   FFr   r   r   r    F)r!   r"   r#   r"   r$   r"   r%   r"   r&   r'   r(   r'   r)   r"   r*   r"   r+   r,   r-   r,   r.   r/   r0   r"   r1   r/   r2   r3   r4   r,   )rc   rd   re   rf   )__name__
__module____qualname____doc__rJ   rl   __classcell__rA   rA   ra   rB   r   /   s(     r   r   r   r   Fvariantr3   r!   r"   r$   r%   r*   r+   r,   re   c              
   C  s   |||||ddddd	}|   dkr tdddddddd	|S |   d
kr4tdddddddd	|S |   dkrHtdddddddd	|S |   dkr\tdddddddd	|S td|  )a  
    Factory method to create MedNeXt variants.

    Args:
        variant (str): The MedNeXt variant to create ('S', 'B', 'M', or 'L').
        spatial_dims (int): Number of spatial dimensions. Defaults to 3.
        in_channels (int): Number of input channels. Defaults to 1.
        out_channels (int): Number of output channels. Defaults to 2.
        kernel_size (int): Kernel size for convolutions. Defaults to 3.
        deep_supervision (bool): Whether to use deep supervision. Defaults to False.

    Returns:
        MedNeXt: The specified MedNeXt variant.

    Raises:
        ValueError: If an invalid variant is specified.
    Tr    Fr   )	r!   r$   r%   r*   r+   r-   r2   r4   r#   Sr   r   )r&   r(   r)   r.   r0   r1   B)r   r      ru   )ru   ru   r   r   ru   M)r   ru   ru   ru   )ru   ru   ru   r   L)r   ru      rx   )rx   rx   ru   r   rx   zInvalid MedNeXt variant: NrA   )upperr   
ValueError)rr   r!   r$   r%   r*   r+   common_argsrA   rA   rB   create_mednext  sn   			
r|   c                  K     t di | S )Nrs   )rs   r|   kwargsrA   rA   rB   <lambda>_      r   c                  K  r}   )Nrt   )rt   r~   r   rA   rA   rB   r   `  r   c                  K  r}   )Nrv   )rv   r~   r   rA   rA   rB   r   a  r   c                  K  r}   )Nrw   )rw   r~   r   rA   rA   rB   r   b  r   )r   r   r   r   F)rr   r3   r!   r"   r$   r"   r%   r"   r*   r"   r+   r,   re   r   )!
__future__r   collections.abcr   torchtorch.nnrM   Z#monai.networks.blocks.mednext_blockr   r   r   r   __all__Moduler   r|   r   r   r   r   r	   r   r   r   r
   r   r   r   r   r   r   r   r   rA   rA   rA   rB   <module>   s&    aQ