o
    )i(                     @  s|   d dl mZ d dlZd dlmZ g dZddd
dZG dd dejZG dd deZ	G dd deZ
G dd dejZdS )    )annotationsN)MedNeXtBlockMedNeXtDownBlockMedNeXtUpBlockMedNeXtOutBlock   Fspatial_dimint	transposeboolc                 C  s(   | dkr|r	t jS t jS |rt jS t jS )N   )nnConvTranspose2dConv2dConvTranspose3dConv3dr   r
    r   e/home/dell461/cl/sdc2/last_ska_mid/HISourceFinder-master-l/src/monai/networks/blocks/mednext_block.pyget_conv_layer   s   r   c                      s8   e Zd ZdZ						dd fddZdd Z  ZS )r   a  
    MedNeXtBlock class for the MedNeXt model.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        expansion_ratio (int): Expansion ratio for the block. Defaults to 4.
        kernel_size (int): Kernel size for convolutions. Defaults to 7.
        use_residual_connection (int): Whether to use residual connection. Defaults to True.
        norm_type (str): Type of normalization to use. Defaults to "group".
        dim (str): Dimension of the input. Can be "2d" or "3d". Defaults to "3d".
        global_resp_norm (bool): Whether to use global response normalization. Defaults to False.
          Tgroup3dFin_channelsr	   out_channelsexpansion_ratiokernel_sizeuse_residual_connection	norm_typestrc	                   s,  t    || _|| _t|dkrdndd}	d|dkrdnd }
|	|||d|d |d| _|dkr:tj||d	| _n|d
krQtj	|g|g|dkrJdnd  d| _|	||| dddd| _
t | _|	|| |dddd| _|| _| jrd|| f|
 }
tjt|
dd| _tjt|
dd| _d S d S )N2dr   r   r   )   r#   r   r   r   stridepaddinggroupsr   )
num_groupsnum_channelslayer)normalized_shaper   )r   r   r   r%   r&   T)requires_grad)super__init__do_resdimr   conv1r   	GroupNormnorm	LayerNormconv2GELUactconv3global_resp_norm	Parametertorchzerosglobal_resp_betaglobal_resp_gamma)selfr   r   r   r   r   r   r0   r9   convZglobal_resp_norm_param_shape	__class__r   r   r.   .   s@   


zMedNeXtBlock.__init__c                 C  s   |}|  |}| | | |}| jrD| jdkr$tj|dddd}n	tj|dddd}||jdddd	  }| j||  | j	 | }| 
|}| jrP|| }|S )
z
        Forward pass of the MedNeXtBlock.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: Output tensor.
        r!   r   )T)pr0   keepdim)rC   rD   r#   )r0   rF   gư>)r1   r7   r5   r3   r9   r0   r;   meanr>   r=   r8   r/   )r?   xx1gxnxr   r   r   forwarde   s   



zMedNeXtBlock.forward)r   r   Tr   r   F)r   r	   r   r	   r   r	   r   r	   r   r	   r   r    __name__
__module____qualname____doc__r.   rM   __classcell__r   r   rA   r   r      s    7r   c                      <   e Zd ZdZ						dd fddZ fddZ  ZS )r   a  
    MedNeXtDownBlock class for downsampling in the MedNeXt model.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        expansion_ratio (int): Expansion ratio for the block. Defaults to 4.
        kernel_size (int): Kernel size for convolutions. Defaults to 7.
        use_residual_connection (bool): Whether to use residual connection. Defaults to False.
        norm_type (str): Type of normalization to use. Defaults to "group".
        dim (str): Dimension of the input. Can be "2d" or "3d". Defaults to "3d".
        global_resp_norm (bool): Whether to use global response normalization. Defaults to False.
    r   r   Fr   r   r   r	   r   r   r   r   r   r   r    r0   r9   c	           
   
     sl   t  j||||d|||d t|dkrdndd}	|| _|r'|	||ddd| _|	|||d|d |d	| _d S )
NFr   r   r0   r9   r!   r   r   r"   r#   r   r   r   r%   r$   )r-   r.   r   resample_do_resres_convr1   
r?   r   r   r   r   r   r   r0   r9   r@   rA   r   r   r.      s,   zMedNeXtDownBlock.__init__c                   s(   t  |}| jr| |}|| }|S )z
        Forward pass of the MedNeXtDownBlock.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: Output tensor.
        )r-   rM   rW   rX   r?   rI   rJ   resrA   r   r   rM      s
   

zMedNeXtDownBlock.forwardr   r   Fr   r   Fr   r	   r   r	   r   r	   r   r	   r   r   r   r    r0   r    r9   r   rN   r   r   rA   r   r      s    %r   c                      rT   )r   a  
    MedNeXtUpBlock class for upsampling in the MedNeXt model.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        expansion_ratio (int): Expansion ratio for the block. Defaults to 4.
        kernel_size (int): Kernel size for convolutions. Defaults to 7.
        use_residual_connection (bool): Whether to use residual connection. Defaults to False.
        norm_type (str): Type of normalization to use. Defaults to "group".
        dim (str): Dimension of the input. Can be "2d" or "3d". Defaults to "3d".
        global_resp_norm (bool): Whether to use global response normalization. Defaults to False.
    r   r   Fr   r   r   r	   r   r   r   r   r   r   r    r0   r9   c	           
   
     st   t  j||||d|||d || _|| _t|dkrdnddd}	|r+|	||ddd	| _|	|||d|d |d
| _d S )NFrU   r!   r   r   Tr   r#   rV   r$   )r-   r.   rW   r0   r   rX   r1   rY   rA   r   r   r.      s.   zMedNeXtUpBlock.__init__c                   s   t  |}| jdkrtjj|d}ntjj|d}| jr>| |}| jdkr2tjj|d}ntjj|d}|| }|S )z
        Forward pass of the MedNeXtUpBlock.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: Output tensor.
        r!   )r#   r   r#   r   )r#   r   r#   r   r#   r   )	r-   rM   r0   r;   r   
functionalpadrW   rX   rZ   rA   r   r   rM      s   



zMedNeXtUpBlock.forwardr\   r]   rN   r   r   rA   r   r      s    &r   c                      s(   e Zd ZdZ fddZdd Z  ZS )r   z
    MedNeXtOutBlock class for the output block in the MedNeXt model.

    Args:
        in_channels (int): Number of input channels.
        n_classes (int): Number of output classes.
        dim (str): Dimension of the input. Can be "2d" or "3d".
    c                   s6   t    t|dkrdnddd}|||dd| _d S )Nr!   r   r   Tr   r#   )r   )r-   r.   r   conv_out)r?   r   	n_classesr0   r@   rA   r   r   r.   %  s   
zMedNeXtOutBlock.__init__c                 C  s
   |  |S )z
        Forward pass of the MedNeXtOutBlock.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: Output tensor.
        )r`   )r?   rI   r   r   r   rM   +  s   

zMedNeXtOutBlock.forwardrN   r   r   rA   r   r     s    	r   )r   F)r   r	   r
   r   )
__future__r   r;   torch.nnr   allr   Moduler   r   r   r   r   r   r   r   <module>   s   cGR